Skip to content

[Question] How to Understand the thr_layout and val_layout in make_tiled_copy_V_interleave_trans? #44

@Maximilianxu

Description

@Maximilianxu
// Builds a tiled copy around the SM75_U16x8_LDSM_T copy atom for element type T.
//
// The thread layout has shape (32,4,1,1) with strides (4,1,0,0); the per-thread
// value layout has shape (2,2,1,4) with strides (1,2,4,4); the tiler is (64,32).
// All three are compile-time constants, so they are spelled as types here.
template <typename T>
__device__ __forceinline__ auto make_tiled_copy_V_interleave_trans() {
  using namespace cute;  // NOLINT

  // Thread arrangement handed to make_tiled_copy.
  using ThrLayout = Layout<Shape<_32, _4, _1, _1>, Stride<_4, _1, _0, _0>>;
  // Value arrangement handed to make_tiled_copy.
  using ValLayout = Layout<Shape<_2, _2, _1, _4>, Stride<_1, _2, _4, _4>>;
  // Explicit tile extents (this call uses the four-argument make_tiled_copy).
  using CopyTiler = Tile<_64, _32>;

  return make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, T>{}, ThrLayout{}, ValLayout{},
                         CopyTiler{});
}

I find this piece of code fascinating. I ran some experiments, but I am still confused.
The following code compiles:

#include <cstdio>

#include <cute/tensor.hpp>
#include <cute/arch/copy.hpp>
#include <cute/arch/copy_sm75.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <cute/atom/copy_atom.hpp>
#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/mma_traits_sm89.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>
#include <cute/algorithm/copy.hpp>
#include "cute/layout.hpp"

#include <cutlass/float8.h>

using namespace cute;

// Element type used for the shared-memory tensor and the register fragment in test_ker.
using Element = cutlass::float_e4m3_t;
// MMA atom the tiled MMA is built from.
// NOTE(review): the atom name lists E5M2 before E4M3, while Element is e4m3 and
// is what test_ker uses to shape the A-operand fragment — confirm the operand
// types are the intended ones.
using MmaOp = SM89_16x8x32_F32E5M2E4M3F32_TN;

// Tiled MMA made of a single MmaOp atom (1x1x1 atom layout) with a 64x16x32 tile.
using TiledMma = decltype(make_tiled_mma(
    MmaOp{}, Layout<Shape<_1, _1, _1>>{}, Tile<_64, _16, _32>{}));

namespace cute {

    // Overload of make_tiled_copy that accepts the tiler from the caller.
    //
    // copy_atom  - the Copy_Atom to tile.
    // thr_layout - thread layout; its size becomes the thread extent of the TV layout.
    // val_layout - value layout; its size becomes the value extent of the TV layout.
    // tiler      - tiler forwarded unchanged to make_tiled_copy_impl.
    //
    // Returns whatever make_tiled_copy_impl builds from the derived TV layout.
    template <class... Args, class ThrLayout, class ValLayout, class Tiler>
    CUTE_HOST_DEVICE auto make_tiled_copy(Copy_Atom<Args...> const &copy_atom,
                                            ThrLayout const &thr_layout, ValLayout const &val_layout,
                                            Tiler const &tiler) {
        // Extents of the (thr_idx, val_idx) domain.
        auto tv_shape = make_shape(size(thr_layout), size(val_layout));

        // Raked product of the two layouts: (M,N) -> (thr_idx, val_idx)
        auto mn_to_tv = raked_product(thr_layout, val_layout);

        // Right inverse, reshaped to the TV extents: (thr_idx, val_idx) -> (M,N)
        auto tv_to_mn = right_inverse(mn_to_tv).with_shape(tv_shape);

        return make_tiled_copy_impl(copy_atom, tv_to_mn, tiler);
    }

}  // namespace cute

// Builds a tiled copy around the SM75_U16x8_LDSM_T copy atom for element type T.
//
// The thread layout has shape (8,4) with strides (4,1), i.e. 32 threads; the
// per-thread value layout has shape (2,8) with strides (1,2); the tiler is (16,32).
// All three are compile-time constants, so they are spelled as types here.
template <typename T>
__device__ __forceinline__ auto make_tiled_copy_V_interleave_trans() {
  using namespace cute;  // NOLINT

  // Thread arrangement handed to make_tiled_copy.
  using ThrLayout = Layout<Shape<_8, _4>, Stride<_4, _1>>;
  // Value arrangement handed to make_tiled_copy.
  using ValLayout = Layout<Shape<_2, _8>, Stride<_1, _2>>;
  // Explicit tile extents (this call uses the four-argument make_tiled_copy).
  using CopyTiler = Tile<_16, _32>;

  return make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, T>{}, ThrLayout{}, ValLayout{},
                         CopyTiler{});
}

// Shared-memory layout atom for Element.
// NOTE(review): presumably the K-major, 128-byte-swizzle GMMA atom, going by its name — confirm.
using SmemLayoutAtom = GMMA::Layout_K_SW128_Atom<Element>;
// The atom tiled out to a 128 x 128 shape; test_ker views its dynamic shared memory through it.
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<Int<128>, Int<128>>{}));

// Experiment kernel: partitions a shared-memory tensor with the tiled copy from
// make_tiled_copy_V_interleave_trans, copies it into a register fragment shaped
// like the tiled MMA's A fragment, and has thread 0 print the tiled copy as LaTeX.
//
// Uses dynamic shared memory (extern __shared__), viewed through SmemLayout as
// Element values. The kernel only reads shared memory; it never writes it.
__global__ void test_ker() {
    extern __shared__ char smem[];
    const int thread_idx = threadIdx.x;

    // Shared-memory tensor over the raw dynamic allocation.
    auto* smem_elems = reinterpret_cast<Element*>(smem);
    auto smem_tensor = make_tensor(make_smem_ptr(smem_elems), SmemLayout{});

    // The same tensor composed with a 128x128 layout of strides (128, 1).
    auto view_layout = make_layout(make_shape(_128{}, _128{}), make_stride(_128{}, _1{}));
    auto smem_view = smem_tensor.compose(view_layout);

    // Register fragment: takes only its shape from the MMA's A-fragment partitioning.
    TiledMma mma;
    auto mma_slice = mma.get_slice(thread_idx);
    auto frag_a = mma_slice.partition_fragment_A(smem_tensor);
    auto regs = make_tensor<Element>(make_layout(frag_a.shape()));

    // Per-thread source (shared) and destination (register) views for the copy.
    auto s2r_copy = make_tiled_copy_V_interleave_trans<Element>();
    auto s2r_slice = s2r_copy.get_slice(thread_idx);
    auto src = s2r_slice.partition_S(smem_view);
    auto dst = s2r_slice.retile_D(regs);

    copy(s2r_copy, src, dst);

    if (thread0()) {
        print_latex(s2r_copy);
    }
}

// Host driver: launches test_ker once and reports any CUDA error.
//
// Launch configuration:
//   - one block of 32 threads: the tiled copy's thread layout is 8x4 = 32
//     threads, and the ldmatrix-based copy atom must be executed by a whole warp;
//   - dynamic shared memory sized to the full codomain of SmemLayout, which
//     test_ker indexes through `extern __shared__`.
//
// Returns 0 on success, 1 if the launch or the kernel execution failed.
int main(void) {
    constexpr int kThreads = 32;
    constexpr int kSmemBytes =
        static_cast<int>(cute::cosize_v<SmemLayout> * sizeof(Element));

    test_ker<<<1, kThreads, kSmemBytes>>>();

    // Launch-configuration errors show up immediately; in-kernel faults only
    // surface at the next synchronizing call.
    cudaError_t status = cudaGetLastError();
    if (status == cudaSuccess) {
        status = cudaDeviceSynchronize();
    }
    if (status != cudaSuccess) {
        std::fprintf(stderr, "test_ker failed: %s\n", cudaGetErrorString(status));
        return 1;
    }
    return 0;
}

However, once I change the tiled copy's layouts and tiler to:

  // Variant reported to fail: compared with the version that compiles, the
  // value layout grows from (2,8) to (4,8) and the tiler from (16,32) to (32,32).
  auto thr_layout = make_layout(make_shape(Int<8>{}, Int<4>{}),
                                make_stride(Int<4>{}, Int<1>{}));
  auto val_layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
                                make_stride(Int<1>{}, Int<4>{}));

  auto tiler = make_tile(Int<32>{}, Int<32>{});

I get the error `Copy_Traits: src failed to vectorize into registers`, which is exactly the same error I get if I use `auto smem_tiled_copy = make_tiled_copy_A(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, tiled_mma);`.

How should I understand these layouts? @reed-lau @shaochangxu

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions