// Builds an LDSM-transpose (SM75_U16x8_LDSM_T) TiledCopy for element type T
// from an explicit thread layout, value layout and 64 x 32 tiler.
template <typename T>
__device__ __forceinline__ auto make_tiled_copy_V_interleave_trans() {
  using namespace cute;  // NOLINT
  // Threads: 32 x 4 = 128 threads; the two trailing singleton modes have stride 0.
  auto const threads = make_layout(make_shape(Int<32>{}, Int<4>{}, Int<1>{}, Int<1>{}),
                                   make_stride(Int<4>{}, Int<1>{}, Int<0>{}, Int<0>{}));
  // Values per thread: 2 x 2 x 1 x 4 = 16 values.
  auto const values = make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}, Int<4>{}),
                                  make_stride(Int<1>{}, Int<2>{}, Int<4>{}, Int<4>{}));
  return make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, T>{}, threads, values,
                         make_tile(Int<64>{}, Int<32>{}));
}
I find this piece of code amazing. I ran some experiments, but I am still confused.
The following code compiles:
#include <cstdio>

#include <cute/tensor.hpp>
#include <cute/arch/copy.hpp>
#include <cute/arch/copy_sm75.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/mma_traits_sm89.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>
#include <cute/algorithm/copy.hpp>
#include "cute/layout.hpp"
#include <cutlass/float8.h>
using namespace cute;
// Element type of the shared-memory tile and of the A-side register fragment
// built in test_ker.
using Element = cutlass::float_e4m3_t;
// Single SM89 fp8 MMA instruction used to build the TiledMma.
// NOTE(review): the op name suggests an E5M2 A operand and an E4M3 B operand,
// while Element (used for the A-side smem tile / register fragment) is e4m3 --
// confirm the mismatch is intended, or consider an E4M3 x E4M3 op.
using MmaOp = SM89_16x8x32_F32E5M2E4M3F32_TN;
// TiledMma made from one MmaOp atom (atom layout 1 x 1 x 1) with
// Tile<_64, _16, _32> passed as make_tiled_mma's third argument.
using TiledMma = decltype(make_tiled_mma(
MmaOp{}, Layout<Shape<_1, _1, _1>>{}, Tile<_64, _16, _32>{}));
namespace cute {
// Overload of make_tiled_copy that takes an explicit tiler.
//
// It builds the same (thr_idx, val_idx) -> (M, N) layout_tv from the thread
// and value layouts, but hands the caller-supplied `tiler` to
// make_tiled_copy_impl instead of deriving the tile from the layouts.
//
//   copy_atom  - copy instruction atom (e.g. an LDSM atom)
//   thr_layout - layout of the participating threads
//   val_layout - layout of the values owned by each thread
//   tiler      - tile passed through to make_tiled_copy_impl
//                (NOTE(review): presumably must be consistent with the
//                 area covered by layout_tv -- confirm)
//
// Returns the TiledCopy produced by make_tiled_copy_impl.
template <class... Args, class ThrLayout, class ValLayout, class Tiler>
CUTE_HOST_DEVICE auto make_tiled_copy(Copy_Atom<Args...> const &copy_atom,
                                      ThrLayout const &thr_layout,
                                      ValLayout const &val_layout,
                                      Tiler const &tiler) {
  // Take the raked_product to compute Layout_MN:
  //   (M, N) -> (thr_idx, val_idx)
  auto layout_mn = raked_product(thr_layout, val_layout);
  // Invert it and reshape into Layout_TV:
  //   (thr_idx, val_idx) -> (M, N)
  auto layout_tv =
      right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));
  return make_tiled_copy_impl(copy_atom, layout_tv, tiler);
}
}  // namespace cute
// Builds an LDSM-transpose (SM75_U16x8_LDSM_T) TiledCopy for element type T:
// a row-major 8 x 4 grid of 32 threads, 16 values per thread, 16 x 32 tiler
// (32 threads x 16 values = 512 = 16 x 32 elements).
template <typename T>
__device__ __forceinline__ auto make_tiled_copy_V_interleave_trans() {
  using namespace cute;  // NOLINT
  // (8,4):(4,1) -- thread (i, j) gets index 4*i + j.
  auto const threads = make_layout(make_shape(Int<8>{}, Int<4>{}),
                                   make_stride(Int<4>{}, Int<1>{}));
  // (2,8):(1,2) -- 16 values per thread.
  auto const values = make_layout(make_shape(Int<2>{}, Int<8>{}),
                                  make_stride(Int<1>{}, Int<2>{}));
  return make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, T>{}, threads, values,
                         make_tile(Int<16>{}, Int<32>{}));
}
// Shared-memory layout atom for Element: CuTe's GMMA Layout_K_SW128_Atom
// (K-major with 128-byte swizzle, per its name).
using SmemLayoutAtom = GMMA::Layout_K_SW128_Atom<Element>;
// The atom tiled out to a full 128 x 128 shared-memory layout.
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<128>, Int<128>>{}));
// Experiment kernel: copies the shared-memory tile into the MMA's A-operand
// register fragment with the custom LDSM-transpose tiled copy, then thread 0
// prints that tiled copy as LaTeX.
// NOTE(review): requires dynamic shared memory for a whole SmemLayout tile of
// Element, and at least one full warp (the LDSM copy and the MMA are warp-wide).
// The shared memory is never initialized; only the layouts matter here.
__global__ void test_ker() {
  extern __shared__ char smem[];
  int const lane = threadIdx.x;

  // Shared-memory tile laid out with SmemLayout.
  auto sTile = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem)), SmemLayout{});
  // The same tile composed with the (128,128):(128,1) layout.
  auto sTileT = sTile.compose(make_layout(make_shape(_128{}, _128{}), make_stride(_128{}, _1{})));

  // Register tensor with the shape of this thread's A-operand fragment.
  TiledMma mmaTiled;
  auto mmaThr = mmaTiled.get_slice(lane);
  auto fragA = mmaThr.partition_fragment_A(sTile);
  auto regA = make_tensor<Element>(make_layout(fragA.shape()));

  // Shared -> register copy through the custom tiled copy.
  auto s2rCopy = make_tiled_copy_V_interleave_trans<Element>();
  auto s2rThr = s2rCopy.get_slice(lane);
  auto srcView = s2rThr.partition_S(sTileT);
  auto dstView = s2rThr.retile_D(regA);
  copy(s2rCopy, srcView, dstView);

  if (thread0()) {
    print_latex(s2rCopy);
  }
}
// Launches test_ker once and reports any launch or execution error.
// Returns 0 on success and 1 if the launch or the kernel failed.
int main(void) {
  // test_ker addresses a whole SmemLayout tile through extern __shared__
  // storage, so the launch must supply that much dynamic shared memory.
  // Launching with 0 bytes makes every shared-memory access out of bounds.
  constexpr size_t smem_bytes = cute::cosize_v<SmemLayout> * sizeof(Element);
  // The SM89 MMA atom and the ldmatrix-based tiled copy (8 x 4 = 32 threads)
  // are warp-collective, so launch one full warp rather than a single thread.
  constexpr int kThreads = 32;
  test_ker<<<1, kThreads, smem_bytes>>>();

  cudaError_t err = cudaGetLastError();  // launch-configuration errors
  if (err == cudaSuccess) {
    err = cudaDeviceSynchronize();  // errors raised while the kernel ran
  }
  if (err != cudaSuccess) {
    std::fprintf(stderr, "test_ker failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  return 0;
}
However, once I changed the tiled copy to:
// Modified layouts (this version fails with "src failed to vectorize into
// registers"). Compared with the working version, only val_layout and tiler change:
//   thr_layout (8,4):(4,1) -> 32 threads (unchanged)
//   val_layout (4,8):(1,4) -> 32 values per thread (was (2,8):(1,2), 16 values)
//   tiler      32 x 32     -> 1024 = 32 threads x 32 values (was 16 x 32)
// NOTE(review): SM75_U16x8_LDSM_T moves 8 x 16-bit = 16 bytes per thread per
// instruction, i.e. 16 fp8 values. The working val_layout matches that exactly;
// this one needs two atom applications per thread, and presumably their
// interleaving is what defeats source vectorization -- verify.
auto thr_layout = make_layout(make_shape(Int<8>{}, Int<4>{}),
make_stride(Int<4>{}, Int<1>{}));
auto val_layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
make_stride(Int<1>{}, Int<4>{}));
auto tiler = make_tile(Int<32>{}, Int<32>{});
I got a `Copy_Traits: src failed to vectorize into registers` error. It is exactly the same error I get if I use `auto smem_tiled_copy = make_tiled_copy_A(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, tiled_mma);`.
How should I understand this layout? @reed-lau @shaochangxu
I find this piece of code amazing. I ran some experiments, but I am still confused.
The following code compiles:
However, once I changed the tiled copy to the version shown above,
I got a
`Copy_Traits: src failed to vectorize into registers` error. It is exactly the same error I get if I use `auto smem_tiled_copy = make_tiled_copy_A(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, tiled_mma);`. How should I understand this layout? @reed-lau @shaochangxu