| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #pragma once |
|
|
| #include <cute/config.hpp> |
| #include <cute/pointer_base.hpp> |
| #include <cute/pointer_sparse.hpp> |
| #include <cute/container/array_subbyte.hpp> |
| #include <cute/numeric/integral_constant.hpp> |
| #include <cute/numeric/numeric_types.hpp> |
|
|
| namespace cute |
| { |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| template <class NewT> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(void* ptr) |
| { |
| if constexpr (is_sparse<NewT>::value) { |
| constexpr int sparsity = NewT::sparsity; |
| NewT* p = reinterpret_cast<NewT*>(ptr); |
| return make_sparse_ptr<sparsity>(p); |
| } else |
| if constexpr (cute::is_subbyte_v<NewT>) { |
| return subbyte_iterator<NewT>(ptr); |
| } else { |
| return reinterpret_cast<NewT*>(ptr); |
| } |
| CUTE_GCC_UNREACHABLE; |
| } |
|
|
| template <class NewT> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(void const* ptr) |
| { |
| if constexpr (is_sparse<NewT>::value) { |
| constexpr int sparsity = NewT::sparsity; |
| NewT const* p = reinterpret_cast<NewT const*>(ptr); |
| return make_sparse_ptr<sparsity>(p); |
| } else |
| if constexpr (cute::is_subbyte_v<NewT>) { |
| return subbyte_iterator<NewT const>(ptr); |
| } else { |
| return reinterpret_cast<NewT const*>(ptr); |
| } |
| CUTE_GCC_UNREACHABLE; |
| } |
|
|
| |
| template <class NewT> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(decltype(nullptr)) { |
| return recast_ptr<NewT>(static_cast<NewT*>(nullptr)); |
| } |
|
|
| |
| |
| |
|
|
| template <class P> |
| struct gmem_ptr : iter_adaptor<P, gmem_ptr<P>> { |
| using iter_adaptor<P, gmem_ptr<P>>::iter_adaptor; |
| }; |
|
|
| template <class T, class = void> |
| struct is_gmem : false_type {}; |
| template <class P> |
| struct is_gmem<gmem_ptr<P>> : true_type {}; |
| template <class P> |
| struct is_gmem<P, void_t<typename P::iterator>> : is_gmem<typename P::iterator> {}; |
| template <class P> |
| constexpr bool is_gmem_v = is_gmem<P>::value; |
|
|
| |
| template <class Iterator> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_gmem_ptr(Iterator iter) { |
| if constexpr (is_gmem<Iterator>::value) { |
| return iter; |
| } else { |
| return gmem_ptr<Iterator>{iter}; |
| } |
| CUTE_GCC_UNREACHABLE; |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_gmem_ptr(void* ptr) { |
| return make_gmem_ptr(recast_ptr<T>(ptr)); |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_gmem_ptr(void const* ptr) { |
| return make_gmem_ptr(recast_ptr<T const>(ptr)); |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_gmem_ptr(decltype(nullptr)) { |
| return make_gmem_ptr(recast_ptr<T>(nullptr)); |
| } |
|
|
| |
| template <class NewT, class P> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(gmem_ptr<P> const& ptr) { |
| return make_gmem_ptr(recast_ptr<NewT>(ptr.get())); |
| } |
|
|
| |
| |
| |
|
|
| template <class P> |
| struct smem_ptr : iter_adaptor<P, smem_ptr<P>> { |
| using iter_adaptor<P, smem_ptr<P>>::iter_adaptor; |
| }; |
|
|
| template <class T, class = void> |
| struct is_smem : false_type {}; |
| template <class P> |
| struct is_smem<smem_ptr<P>> : true_type {}; |
| template <class P> |
| struct is_smem<P, void_t<typename P::iterator>> : is_smem<typename P::iterator> {}; |
| template <class P> |
| constexpr bool is_smem_v = is_smem<P>::value; |
|
|
| |
| template <class Iterator> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_smem_ptr(Iterator iter) { |
| if constexpr (is_smem<Iterator>::value) { |
| return iter; |
| } else { |
| return smem_ptr<Iterator>{iter}; |
| } |
| CUTE_GCC_UNREACHABLE; |
| } |
|
|
| |
| template <class Iterator, class Swizzle> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_smem_ptr(Iterator ptr, Swizzle sw) |
| { |
| return make_swizzle_ptr(make_smem_ptr(ptr), sw); |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_smem_ptr(void* ptr) { |
| return make_smem_ptr(recast_ptr<T>(ptr)); |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_smem_ptr(void const* ptr) { |
| return make_smem_ptr(recast_ptr<T const>(ptr)); |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_smem_ptr(decltype(nullptr)) { |
| return make_smem_ptr(recast_ptr<T>(nullptr)); |
| } |
|
|
| |
| template <class NewT, class P> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(smem_ptr<P> const& ptr) { |
| return make_smem_ptr(recast_ptr<NewT>(ptr.get())); |
| } |
|
|
| |
| |
| |
|
|
| template <class P> |
| struct rmem_ptr : iter_adaptor<P, rmem_ptr<P>> { |
| using iter_adaptor<P, rmem_ptr<P>>::iter_adaptor; |
| }; |
|
|
| |
| template <class T, class = void> |
| struct is_rmem : bool_constant<not (is_gmem<T>::value || is_smem<T>::value)> {}; |
| template <class P> |
| struct is_rmem<rmem_ptr<P>> : true_type {}; |
| template <class P> |
| constexpr bool is_rmem_v = is_rmem<P>::value; |
|
|
| |
| template <class Iterator> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_rmem_ptr(Iterator iter) { |
| if constexpr (is_rmem<Iterator>::value) { |
| return iter; |
| } else { |
| return rmem_ptr<Iterator>{iter}; |
| } |
| CUTE_GCC_UNREACHABLE; |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_rmem_ptr(void* ptr) { |
| return make_rmem_ptr(recast_ptr<T>(ptr)); |
| } |
|
|
| |
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| make_rmem_ptr(void const* ptr) { |
| return make_rmem_ptr(recast_ptr<T const>(ptr)); |
| } |
|
|
| |
| template <class NewT, class P> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(rmem_ptr<P> const& ptr) { |
| return make_rmem_ptr(recast_ptr<NewT>(ptr.get())); |
| } |
|
|
|
|
| |
| |
| |
|
|
| template <class T> |
| struct tmem_ptr |
| { |
| using value_type = remove_cv_t<T>; |
| using element_type = T; |
| using reference = T; |
|
|
| |
| static constexpr int32_t OffsetShift = log_2(trait_ratio(sizeof_bits<uint32_t>{}, sizeof_bits<T>{})); |
|
|
| CUTE_HOST_DEVICE constexpr |
| tmem_ptr(uint32_t addr = 0) : addr_(addr) {} |
|
|
| CUTE_HOST_DEVICE constexpr |
| uint32_t const& get() const { |
| return addr_; |
| } |
| CUTE_HOST_DEVICE constexpr |
| uint32_t& get() { |
| return addr_; |
| } |
|
|
| template <class T_ = T> |
| CUTE_HOST_DEVICE constexpr |
| value_type operator*() const { |
| static_assert(dependent_false<T_>, "Attempting to dereference a tmem_ptr, want raw_pointer_cast() for address instead?"); |
| return value_type{}; |
| } |
|
|
| CUTE_HOST_DEVICE constexpr |
| reference operator[](uint32_t const& i) const { return *(*this + i); } |
|
|
| CUTE_HOST_DEVICE constexpr |
| tmem_ptr operator+(uint32_t const& i) const { |
| |
| return {addr_ + rotr(i, OffsetShift)}; |
| } |
|
|
| |
| |
| |
| union { |
| uint32_t addr_; |
| struct { |
| uint16_t col_; |
| uint8_t dp_; |
| uint8_t idx_; |
| |
| }; |
| }; |
| }; |
|
|
| template <class T, class = void> |
| struct is_tmem : false_type {}; |
| template <class T> |
| struct is_tmem<tmem_ptr<T>> : true_type {}; |
| template <class P> |
| struct is_tmem<P, void_t<typename P::iterator>> : is_tmem<typename P::iterator> {}; |
| template <class P> |
| constexpr bool is_tmem_v = is_tmem<P>::value; |
|
|
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| tmem_ptr<T> |
| make_tmem_ptr(uint32_t addr = 0) { |
| return tmem_ptr<T>(addr); |
| } |
|
|
| template <class T> |
| CUTE_HOST_DEVICE constexpr |
| uint32_t |
| raw_pointer_cast(tmem_ptr<T> const& ptr) { |
| return ptr.get(); |
| } |
|
|
| |
| |
| template <class NewT, class T> |
| CUTE_HOST_DEVICE constexpr |
| auto |
| recast_ptr(tmem_ptr<T> const& ptr) { |
| return tmem_ptr<NewT>{ptr.addr_}; |
| } |
|
|
|
|
| |
| |
| |
|
|
| template <class T> |
| CUTE_HOST_DEVICE void print(gmem_ptr<T> ptr) |
| { |
| printf("gmem_"); print(ptr.get()); |
| } |
|
|
| template <class T> |
| CUTE_HOST_DEVICE void print(smem_ptr<T> ptr) |
| { |
| printf("smem_"); print(ptr.get()); |
| } |
|
|
| template <class T> |
| CUTE_HOST_DEVICE void print(rmem_ptr<T> ptr) |
| { |
| printf("rmem_"); print(ptr.get()); |
| } |
|
|
|
|
| template <class T> |
| CUTE_HOST_DEVICE void print(tmem_ptr<T> ptr) |
| { |
| printf("tmem_["); print(sizeof_bits<T>::value); printf("b](0x%04x.%04x)", ptr.addr_ >> 16, ptr.addr_ & 0xFFFF); |
| } |
|
|
|
|
| #if !defined(__CUDACC_RTC__) |
| template <class T> |
| CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr<T> ptr) |
| { |
| return os << "gmem_[" << int(sizeof_bits<iter_value_t<T>>::value) << "b]"; |
| } |
|
|
| template <class T> |
| CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr<T> ptr) |
| { |
| return os << "smem_[" << int(sizeof_bits<iter_value_t<T>>::value) << "b]"; |
| } |
|
|
| template <class T> |
| CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr<T> ptr) |
| { |
| return os << "rmem_[" << int(sizeof_bits<iter_value_t<T>>::value) << "b]"; |
| } |
|
|
|
|
| template <class T> |
| CUTE_HOST std::ostream& operator<<(std::ostream& os, tmem_ptr<T> ptr) |
| { |
| return os << "tmem_[" << int(sizeof_bits<T>::value) << "b](" << ptr.addr_ << ")"; |
| } |
|
|
| #endif |
|
|
| } |
|
|