From 839cf8974c0f351a0aa11e0c76f05f6eb14ffcb8 Mon Sep 17 00:00:00 2001 From: feifei14119 Date: Sat, 8 Feb 2025 11:23:02 +0800 Subject: [PATCH 1/3] [flatmm] implement framwork --- example/ck_tile/18_flatmm/CMakeLists.txt | 9 + example/ck_tile/18_flatmm/README.md | 35 + example/ck_tile/18_flatmm/flatmm_basic.cpp | 123 +++ example/ck_tile/18_flatmm/flatmm_basic.hpp | 100 +++ .../ck_tile/18_flatmm/run_flatmm_example.inc | 728 ++++++++++++++++++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/ops/flatmm.hpp | 47 ++ .../ops/flatmm/kernel/flatmm_kernel.hpp | 621 +++++++++++++++ .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 296 +++++++ ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 517 +++++++++++++ ...tmm_universal_pipeline_ag_bg_cr_policy.hpp | 470 +++++++++++ 11 files changed, 2947 insertions(+) create mode 100644 example/ck_tile/18_flatmm/CMakeLists.txt create mode 100644 example/ck_tile/18_flatmm/README.md create mode 100644 example/ck_tile/18_flatmm/flatmm_basic.cpp create mode 100644 example/ck_tile/18_flatmm/flatmm_basic.hpp create mode 100644 example/ck_tile/18_flatmm/run_flatmm_example.inc create mode 100644 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt new file mode 100644 index 0000000000..dc52e049d3 --- /dev/null +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -0,0 +1,9 @@ +add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) + +set(EXAMPLE_FLATMM_COMPILE_OPTIONS) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -ggdb -g -O0 -v -save-temps) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1) +target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md new file mode 100644 index 0000000000..beaac785fc --- /dev/null +++ b/example/ck_tile/18_flatmm/README.md @@ -0,0 +1,35 @@ +# FLATMM Matrix Multiplication + +This folder contains example for FLATMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile FLATMM, but creates the placeholders for the future support on different FLATMM pipeline and different FLATMM modules. In the near future, we will gradually migrate all the FLATMM features from old CK to CK Tile. 
+ +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the flatmm calculation +make tile_example_flatmm_basic -j +``` +This will result in an executable `build/bin/tile_example_flatmm_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp new file mode 100644 index 0000000000..a46d61c8f4 --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "flatmm_basic.hpp" + +template +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool kTilePermute = false; + // The rank and permutation will also be generate out by the CodeGen part. + constexpr ck_tile::index_t kOutputRank = 2; + + constexpr int kBlockPerCu = 1; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + // Whether doing the CShuffle (transpose before the global memory), depending on the output + // layout. + constexpr bool CShuffleEpilogue = + std::is_same_v; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile2DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
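+    // Illustrative sanity checks on the tile configuration above (not required by the
+    // kernel, added for clarity): a 128x128x32 block tile is covered by a 2x2x1 grid of
+    // waves, each issuing 32x32x8 MFMA tiles, so every wave repeats 128/(2*32) = 2 times
+    // in M, 2 times in N and 32/8 = 4 times in K per block iteration. Assuming a 64-lane
+    // wavefront, that corresponds to 2*2*64 = 256 threads per workgroup.
+    static_assert(M_Tile % (M_Warp * M_Warp_Tile) == 0, "M block tile must be covered by the wave grid");
+    static_assert(N_Tile % (N_Warp * N_Warp_Tile) == 0, "N block tile must be covered by the wave grid");
+    static_assert(K_Tile % (K_Warp * K_Warp_Tile) == 0, "K block tile must be covered by the wave grid");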
+ using Kernel = ck_tile::FlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + +#if FEIFEI_DEBUG + printf("[FEIFEI] --- flatmm_calc() ---\n"); + printf("[FEIFEI] BlockPerCu = %d\n", static_cast(kBlockPerCu)); + printf("[FEIFEI] BlockTile M = %d\n", static_cast(M_Tile)); + printf("[FEIFEI] BlockTile N = %d\n", static_cast(N_Tile)); + printf("[FEIFEI] BlockTile K = %d\n", static_cast(K_Tile)); + printf("[FEIFEI] WavePerBlock M = %d\n", static_cast(M_Warp)); + printf("[FEIFEI] WavePerBlock N = %d\n", static_cast(N_Warp)); + printf("[FEIFEI] WavePerBlock K = %d\n", static_cast(K_Warp)); + printf("[FEIFEI] WaveTile M = %d\n", static_cast(M_Warp_Tile)); + printf("[FEIFEI] WaveTile N = %d\n", static_cast(N_Warp_Tile)); + printf("[FEIFEI] WaveTile K = %d\n", static_cast(K_Warp_Tile)); + printf("[FEIFEI] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z); + printf("[FEIFEI] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z); +#endif + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} + +#include "run_flatmm_example.inc" + +int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp new file mode 100644 index 0000000000..19d0d362ef --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -0,0 +1,100 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_MEMORY 2 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. 
+}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser + .insert("m", "128", "m dimension") // 128, 3840 + .insert("n", "128", "n dimension") // 128, 4096 + .insert("k", "64", "k dimension") // 64, 2048 + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc new file mode 100644 index 0000000000..3ea7fb3765 --- /dev/null +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -0,0 +1,728 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[0]; + int k_ = t.get_lengths()[1]; + printf("[FF] shuffle_b: mfma_dtype = %s, mfma_type = %d, n_ = %d, k_ = %d\n", + mfma_dtype.c_str(), + mfma_type, + n_, + k_); + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + return t; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_flatmm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat +#if FEIFEI_DEBUG + , + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& dbg_int_buf, + ck_tile::DeviceMem& dbg_fp32_buf, + ck_tile::DeviceMem& dbg_f168_buf +#endif +) +{ + ck_tile::FlatmmHostArgs args; + args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + +#if FEIFEI_DEBUG + args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + args.dbg_int_ptr = dbg_int_buf.GetDeviceBuffer(); + args.dbg_fp32_ptr = dbg_fp32_buf.GetDeviceBuffer(); + args.dbg_f168_ptr = dbg_f168_buf.GetDeviceBuffer(); + + printf("[FEIFEI] --- invoke_flatmm: ---\n"); + printf("[FEIFEI] args.M = %d\n", static_cast(args.M)); + printf("[FEIFEI] 
args.N = %d\n", static_cast(args.N)); + printf("[FEIFEI] args.K = %d\n", static_cast(args.K)); + printf("[FEIFEI] args.stride_A = %d\n", static_cast(args.stride_A)); + printf("[FEIFEI] args.stride_B = %d\n", static_cast(args.stride_B)); + printf("[FEIFEI] args.stride_C = %d\n", static_cast(args.stride_C)); + printf("[FEIFEI] args.k_batch = %d\n", static_cast(args.k_batch)); +#endif + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_flatmm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); +#if FEIFEI_DEBUG + n_warmup = 1; + n_repeat = 2; +#endif + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + // TODO: add different init types + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + +#if FEIFEI_DEBUG + ck_tile::HostTensor dbg_int({M * N * 64}); + ck_tile::HostTensor dbg_fp32({M * N * 64}); + ck_tile::HostTensor dbg_f168({M * N * 64}); + + ck_tile::DeviceMem dbg_int_buf(dbg_int.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbg_fp32_buf(dbg_fp32.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dbg_f168_buf(dbg_f168.get_element_space_size_in_bytes()); +#endif + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + // do pre-shuffle + std::string mfma = arg_parser.get_str("prec"); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n, mfma, 1); + ck_tile::DeviceMem 
b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); + b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + + invoke_flatmm(a_m_k_dev_buf, + b_shuffle_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat +#if FEIFEI_DEBUG + , + b_k_n_dev_buf, + dbg_int_buf, + dbg_fp32_buf, + dbg_f168_buf +#endif + ); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; +#if FEIFEI_DEBUG + // c_ref + { + std::ofstream file("ff_c_cpu_ref.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [c_cpu_ref]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + if(x % 64 == 0) + { + file << "\n [" << x << " : " << x + 63 << "]: "; + } + int idx = X * y + x; + file << ck_tile::type_convert(c_m_n_host_ref.mData[idx]) << ", "; + } + } + + file.close(); + } +#endif + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy(d_A, + a_m_k_dev_buf.GetDeviceBuffer(), + M * K * sizeof(ADataType), + hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_k_n_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << 
rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; +#if FEIFEI_DEBUG + // c_ref + { + std::ofstream file("ff_c_gpu_ref.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [c_gpu_ref]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + if(x % 64 == 0) + { + file << "\n [" << x << " : " << x + 63 << "]: "; + } + int idx = X * y + x; + file << ck_tile::type_convert(c_m_n_gpu_ref.mData[idx]) << ", "; + } + } + + file.close(); + } +#endif + } + +#if FEIFEI_DEBUG + int GridDimX = 1; + int GridDimY = 1; + int BlockDimX = 64; + int BlockDimY = 4; + int DbgCnt = 64; + int BlockSize = BlockDimX * BlockDimY; + // a_host + { + std::ofstream file("ff_a_host.txt"); + int X = static_cast(K); + int Y = static_cast(M); + file << " [a_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + if(idx % 16 == 0) + { + file << "\n [" << x << " : " << x + 15 << " ]: "; + } + + file << ck_tile::type_convert(a_m_k.mData[idx]) << ", "; + } + } + + file.close(); + } + // b_host + { + std::ofstream file("ff_b_host.txt"); + int X = static_cast(K); + int Y = static_cast(N); + file << " [b_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + if(idx % 16 == 0) + { + file << "\n [" << x << " : " << x + 15 << " ]: "; + } + + file << ck_tile::type_convert(b_k_n.mData[idx]) << ", "; + } + } + + file.close(); + } + // b_shuffle + { + std::ofstream file("ff_b_shuffle_host.txt"); + int X = static_cast(K); + int Y = static_cast(N); + file << " [b_shuffle_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + if(idx % 16 == 0) + { + file << "\n [" << x << " : " << x + 15 << " ]: "; + } + + file << ck_tile::type_convert(b_shuffle_host.mData[idx]) << ", "; + } + } + + file.close(); + } + // c_dev ---> kernel + { + auto c_dev = c_m_n_dev_buf.ToHost(); + std::ofstream file("ff_c_dev_kernel.txt"); + file << " [c_dev]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize + << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < DbgCnt; i++) // multi output per thread + file << ck_tile::type_convert(c_dev.mData[gid * DbgCnt + i]) << ", "; + } + } + } + + file.close(); + } + // c_dev + { + // auto d_dev = d_buf.ToHost(); + auto c_dev = c_m_n_dev_buf.ToHost(); + std::ofstream file("ff_c_dev.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [c_dev]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + 
for(int x = 0; x < X; x++) + { + if(x % 64 == 0) + { + file << "\n [" << x << " : " << x + 63 << "]: "; + } + int idx = X * y + x; + file << ck_tile::type_convert(c_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // dbg_int ---> kernel + { + auto dbg_int_dev = dbg_int_buf.ToHost(); + std::ofstream file("ff_dbg_int_kernel.txt"); + file << " [dbg_int]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize + << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < DbgCnt; i++) + file << ck_tile::type_convert(dbg_int_dev.mData[gid * DbgCnt + i]) + << ", "; + } + } + } + + file.close(); + } + // dbg_int + { + auto dbg_int_dev = dbg_int_buf.ToHost(); + std::ofstream file("ff_dbg_int.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [dbg_int]: Row = " << Y << ", Col = " << X << std::endl; + + for(int m = 0; m < Y; m++) + { + file << "\n ========== row : [" << m << " / " << Y << "] =========="; + for(int n = 0; n < X; n++) + { + if(n % 64 == 0) + { + file << "\n [" << n << " : " << n + 63 << "]: "; + } + int idx = X * m + n; + file << ck_tile::type_convert(dbg_int_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // dbg_fp32 ---> kernel + { + auto dbg_fp32_dev = dbg_fp32_buf.ToHost(); + std::ofstream file("ff_dbg_fp32_kernel.txt"); + file << " [dbg_fp32]: Grid = [" << GridDimX << ", " << GridDimY + << "], Block = " << BlockSize << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < DbgCnt; i++) + file << ck_tile::type_convert(dbg_fp32_dev.mData[gid * DbgCnt + i]) + << ", "; + } + } + } + + file.close(); + } + // dbg_fp32 + { + auto dbg_fp32_dev = dbg_fp32_buf.ToHost(); + std::ofstream file("ff_dbg_fp32.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [dbg_fp32]: Row = " << Y << ", Col = " << X << std::endl; + + for(int m = 0; m < Y; m++) + { + file << "\n ========== row : [" << m << " / " << Y << "] =========="; + for(int n = 0; n < X; n++) + { + if(n % 64 == 0) + { + file << "\n [" << n << " : " << n + 63 << "]: "; + } + int idx = X * m + n; + file << ck_tile::type_convert(dbg_fp32_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // dbg_fp16 ---> kernel + { + auto dbg_fp16_dev = dbg_f168_buf.ToHost(); + std::ofstream file("ff_dbg_fp16_kernel.txt"); + file << " [dbg_fp16]: Grid = [" << GridDimX << ", " << GridDimY + << "], Block = " << BlockSize << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < DbgCnt; i++) + file << ck_tile::type_convert(dbg_fp16_dev.mData[gid * DbgCnt + i]) + << ", "; + } + } + } + + file.close(); + } + // dbg_fp16 + { + auto dbg_fp16_dev = dbg_f168_buf.ToHost(); + std::ofstream 
file("ff_dbg_fp16.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [dbg_fp16]: Row = " << Y << ", Col = " << X << std::endl; + + for(int m = 0; m < Y; m++) + { + file << "\n ========== row : [" << m << " / " << Y << "] =========="; + for(int n = 0; n < X; n++) + { + if(n % 64 == 0) + { + file << "\n [" << n << " : " << n + 63 << "]: "; + } + int idx = X * m + n; + file << ck_tile::type_convert(dbg_fp16_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // dbg_fp8 ---> kernel + { + auto dbg_fp8_dev = dbg_f168_buf.ToHost(); + std::ofstream file("ff_dbg_fp8_kernel.txt"); + file << " [dbg_fp8]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize + << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < DbgCnt; i++) + file << ck_tile::type_convert(dbg_fp8_dev.mData[gid * DbgCnt + i]) + << ", "; + } + } + } + + file.close(); + } + // dbg_fp8 + { + auto dbg_fp8_dev = dbg_f168_buf.ToHost(); + std::ofstream file("ff_dbg_fp8.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [dbg_fp8]: Row = " << Y << ", Col = " << X << std::endl; + + for(int m = 0; m < Y; m++) + { + file << "\n ========== row : [" << m << " / " << Y << "] =========="; + for(int n = 0; n < X; n++) + { + if(n % 64 == 0) + { + file << "\n [" << n << " : " << n + 63 << "]: "; + } + int idx = X * m + n; + file << ck_tile::type_convert(dbg_fp8_dev.mData[idx]) << ", "; + } + } + + file.close(); + } +#endif + + return pass; +} + +int run_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "R") + { + return run_flatmm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not + // work. 
+ // else if(a_layout == "C" && b_layout == "C") + // { + // return run_flatmm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + // } + // else if(a_layout == "C" && b_layout == "R") + // { + // return run_flatmm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + // } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7f4ba2ed35..88efe0d8d9 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 82f6d48eda..8c8e37d90f 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -3,6 +3,53 @@ #pragma once +#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +// #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" +// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +// #include 
"ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" + +#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp" + #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp new file mode 100644 index 0000000000..dd52739545 --- /dev/null +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +struct FlatmmProblem +{ + CK_TILE_HOST FlatmmProblem() = default; + CK_TILE_HOST FlatmmProblem( + index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) + : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; +}; + +struct FlatmmHostArgs : public FlatmmProblem +{ + CK_TILE_HOST FlatmmHostArgs() = default; + CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_, + const void* b_shuffle_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_) + : FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), + a_ptr(a_ptr_), + b_shuffle_ptr(b_shuffle_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t k_batch; + +#if FEIFEI_DEBUG + const void* b_ptr; + + void* dbg_int_ptr; + void* dbg_fp32_ptr; + void* dbg_f168_ptr; +#endif +}; + +template +struct FlatmmKernel +{ + using TilePartitioner = remove_cvref_t; + using FlatmmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return TilePartitioner::GridSize(M, N); + // return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + struct 
FlatmmKernelArgs + { + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t KBatch; +#if FEIFEI_DEBUG + const void* b_ptr; + + void* dbg_int_ptr; + void* dbg_fp32_ptr; + void* dbg_f168_ptr; +#endif + }; + + CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) + { + return FlatmmKernelArgs{hostArgs.a_ptr, + hostArgs.b_shuffle_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + hostArgs.k_batch +#if FEIFEI_DEBUG + , + hostArgs.b_ptr, + hostArgs.dbg_int_ptr, + hostArgs.dbg_fp32_ptr, + hostArgs.dbg_f168_ptr +#endif + }; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = kargs.KBatch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + + if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead; + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead * kargs.stride_A; + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * kargs.stride_B; + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + + if(k_id < static_cast(kargs.KBatch - 1)) + { + splitted_k = KRead; + } + else + { + splitted_k = kargs.K - KRead * (kargs.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs) + { + if constexpr(EpiloguePipeline::template GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" 
+ << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % EpiloguePipeline::template GetVectorSizeC() != 0) + { + std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % EpiloguePipeline::template GetVectorSizeC() != 0) + { + std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + b_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number()>{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { 
+ return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, b_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& a_block_window = make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {i_m, 0}); + + const auto& b_pad_view = views.at(I1); + const auto& b_block_window = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {i_n, 0}); + + const auto& c_pad_view = views.at(I2); + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_shuffle_ptr input B pointer + * @param c_ptr output C pointer + * @param kargs GEMM kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + * @tparam DstInMemOp Destination memory operation (default: set). + */ + template + CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, + const BDataType* b_shuffle_ptr, + CDataType* c_ptr, + void* smem_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n +#if FEIFEI_DEBUG + , + const BDataType* b_ptr, + int* dbg_int, + float* dbg_fp32, + short* dbg_f168 +#endif + ) + { + + // Create Flatmm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset); + // origin layout + // const auto& gemm_tensor_views_tuple = + // MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + + // Run GEMM cooperatively by whole workgroup. 
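+        // Note: num_loop is the number of K-block iterations the pipeline runs for this
+        // split-K slice (typically ceil(splitted_k / KPerBlock)). The pipeline streams
+        // A/B block tiles from global memory through LDS and returns the accumulator as a
+        // register-distributed tile; the epilogue below (currently guarded out until the
+        // pipeline is implemented) would write that tile back to the C window.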
+ const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window, + b_block_window, + num_loop, + smem_ptr +#if FEIFEI_DEBUG + , + b_ptr, + dbg_int, + dbg_fp32, + dbg_f168 +#endif + ); + + // feifei TODO: Un-comment bellow once pipeline() is implemented +#if 0 + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); + + constexpr bool is_output_c_reg_transposed = + EpiloguePipeline::IsOutputTransposed() != FlatmmPipeline::IsTransposeC(); + if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) || + (FlatmmPipeline::VectorSizeC % 2 == 0 && + std::is_same_v && + is_output_c_reg_transposed)) + { + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile); + } +#endif + } + + CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const + { +#if FEIFEI_DEBUG + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[KERNEL] ===== FlatmmKernel() =====\n"); + printf("[KERNEL] blockDim: [%d, %d], gridDim: [%d, %d]\n", + static_cast(blockDim.x), + static_cast(blockDim.y), + static_cast(gridDim.x), + static_cast(gridDim.y)); + printf("[KERNEL] lds = %.3f (KB)\n", GetSmemSize() / 1024.0f); + } + + uint32_t tidx = threadIdx.x; + uint32_t tidy = threadIdx.y; + uint32_t bidx = blockIdx.x; + uint32_t bidy = blockIdx.y; + uint32_t bdmx = blockDim.x; + uint32_t bdmy = blockDim.y; + uint32_t gdmx = gridDim.x; + uint32_t gdmy = gridDim.y; + uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx; + + const SplitKBatchOffset _splitk_batch_offset(kargs); + const BDataType* b_ptr = + static_cast(kargs.b_ptr) + _splitk_batch_offset.b_k_split_offset; + + int* dbg_int = static_cast(kargs.dbg_int_ptr); + float* dbg_fp32 = static_cast(kargs.dbg_fp32_ptr); + short* dbg_f168 = static_cast(kargs.dbg_f168_ptr); + + dbg_int[gid] = 1; + dbg_fp32[gid] = 1.0f; + dbg_f168[gid] = ck_tile::type_convert(1.0f); +#endif + + const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_shuffle_ptr = static_cast(kargs.b_shuffle_ptr) + + splitk_batch_offset.b_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + if(kargs.KBatch == 1) + { + RunFlatmm(a_ptr, + b_shuffle_ptr, + c_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n +#if FEIFEI_DEBUG + , + b_ptr, + dbg_int, + dbg_fp32, + dbg_f168 +#endif + ); + } + else + { + RunFlatmm(a_ptr, + b_shuffle_ptr, + c_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n +#if FEIFEI_DEBUG + , + b_ptr, + dbg_int, + dbg_fp32, + dbg_f168 +#endif + ); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..4c37e827f3 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct FlatmmPipelineAGmemBGmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem +#if FEIFEI_DEBUG + , + const BDataType* b_ptr, + int* dbg_int, + float* dbg_fp32, + short* dbg_f168 +#endif + ) const + { +#if FEIFEI_DEBUG + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] FlatmmPipelinen():\n"); + } + + uint32_t tidx = threadIdx.x; + uint32_t tidy = threadIdx.y; + uint32_t bidx = blockIdx.x; + uint32_t bidy = blockIdx.y; + uint32_t bdmx = blockDim.x; + uint32_t bdmy = blockDim.y; + uint32_t gdmx = gridDim.x; + uint32_t gdmy = gridDim.y; + uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx; + + dbg_int[gid] = -1; + dbg_fp32[gid] = -1.0f; + dbg_f168[gid] = ck_tile::type_convert(-1.0f); +#endif + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + +#if 1 + // feifei TODO: Implement gemm here + return nullptr; +#else + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 
+ 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + // return c_block_tile; + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegBlockDescriptor()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp, b_block_tile); + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + else + { + store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + } + } + + index_t iCounter = num_loop - 1; + while(iCounter > 0) + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // LDS write i + 1 + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template 
MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp_loop, b_block_tile); + store_tile(b_copy_lds_window, + tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + } + else + { + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + + iCounter--; + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + return c_block_tile; +#endif + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem +#if FEIFEI_DEBUG + , + const BDataType* b_ptr, + int* dbg_int, + float* dbg_fp32, + short* dbg_f168 +#endif + ) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType & a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType & b) { return b; }, + num_loop, + p_smem +#if FEIFEI_DEBUG + , + b_ptr, + dbg_int, + dbg_fp32, + dbg_f168 +#endif + ); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..0cafaa37ab --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -0,0 +1,517 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +// Default policy for GemmPipelineAGmemBGmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct FlatmmPipelineAGmemBGmemCRegV1DefaultPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr bool TransposeC = true; + +#if 0 + // 2d + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{}); + + return a_lds_block_desc; + } + + // 2d + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{}); + + return b_lds_block_desc; + } +#elif 1 + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + // TODO: this 8 is AK1! should be a policy parameter! 
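A minimal sketch of the address math the descriptor below encodes, assuming the hard-coded AK1 = 8 called out in the TODO above; the helper name `a_lds_offset` and its plain-int signature are illustrative only, not part of the policy:

```
// Illustrative sketch, not the real descriptor: where element (m, k) of the A block
// tile lands in LDS given dims (kKPerBlock/8, kMPerBlock, 8) and strides
// ((kMPerBlock + 1) * 8, 8, 1). The "+ 1" pads every K-slab by 8 unused elements,
// presumably to stagger consecutive rows across LDS banks and reduce bank conflicts.
constexpr int a_lds_offset(int m, int k, int kMPerBlock)
{
    constexpr int AK1 = 8; // hard-coded K vector width, see the TODO above
    return (k / AK1) * (kMPerBlock + 1) * AK1 // which K-slab (padded)
         + m * AK1                            // which row inside the slab
         + (k % AK1);                         // position inside the AK1-wide vector
}
```

The transform_tensor_descriptor call that follows then merges the (k / 8, k % 8) pair back into a single K coordinate, so the rest of the pipeline can keep addressing the tile as a plain (m, k) view.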
+ constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number<8>{}), + make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), + number<8>{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + constexpr index_t smem_size = smem_size_a + smem_size_b; + + return smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using ADataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BDataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(BDataType); + } +#elif 1 + // fake XOR + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + + using ADataType = remove_cvref_t; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number<2>{}, number{}), + number{}); + + constexpr index_t kK1 = 16 / sizeof(ADataType); + + constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + a_lds_block_desc_d1_d2_d3, + make_tuple( + make_xor_transform(make_tuple(number{}, number{}), kK1), + make_pass_through_transform(2)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_d4_d5_d6, + make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + make_pass_through_transform(kKPerBlock)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc_m_k; + } + + // fake XOR + template + 
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using namespace ck_tile; + + using BDataType = remove_cvref_t; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( + make_tuple(number{}, number<2>{}, number{}), + number{}); + + constexpr index_t kK1 = 16 / sizeof(BDataType); + + constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( + b_lds_block_desc_d1_d2_d3, + make_tuple( + make_xor_transform(make_tuple(number{}, number{}), kK1), + make_pass_through_transform(2)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_d4_d5_d6, + make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), + make_pass_through_transform(kKPerBlock)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc_n_k; + } +#endif + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! 
" + "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M1, M2 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetSmemPackB(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + static_assert(N0 * N1 * N2 == NPerBlock, + "Incorrect N0, N1, N2 configuration! " + "N0, N1, N2 must cover whole NPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // coalesce reading for each warps + else + { + constexpr index_t N0 = BlockSize / get_warp_size(); + constexpr index_t N1 = NPerBlock / (N2 * N0); + static_assert(N0 * N1 * N2 == NPerBlock, + "Incorrect N0, N1, N2 configuration! 
" + "N0, N1, N2 must cover whole NPerBlock!"); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * N0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * N0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } + + 
template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + + return BlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..c60a1b0b1e --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,470 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +// UniversalGemm Policy +struct UniversalFlatmmPipelineAgBgCrPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr bool TransposeC = true; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + + if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0) + { + return (16 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0) + { + return (8 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && + sizeof(DataType) >= 4) + { + return (4 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && + sizeof(DataType) >= 2) + { + return (2 / sizeof(DataType)); + } + else + { + return 1; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetVectorLoadSize(); + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetVectorLoadSize(); + + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + 
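As a rough sanity check on the sizes these helpers return, here is a back-of-the-envelope sketch for the example's 128x128x32 fp16 tile, assuming the XOR-permuted descriptors above stay fully packed (the swizzle reorders addresses without adding padding); all constants are taken from the example configuration, not from the Problem type:

```
// Illustrative only; the real numbers come from GetSmemSizeA()/GetSmemSizeB().
constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32; // example tile
constexpr int ElemBytes = 2;                                    // fp16
constexpr int smem_a = MPerBlock * KPerBlock * ElemBytes;       // 8 KiB for the A tile
constexpr int smem_b = NPerBlock * KPerBlock * ElemBytes;       // 8 KiB for the B tile
static_assert(smem_a + smem_b == 16 * 1024,
              "comfortably within the 64 KiB of LDS available to a workgroup on gfx9");
```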
template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + index_t smem_size = 0; + smem_size += smem_size_a + smem_size_b; + + return smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetVectorLoadSize(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetVectorLoadSize(); + static_assert(KPack % K3 == 0); + 
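To make the factorisation computed in the remainder of this branch concrete, here is how the numbers work out for the example configuration from flatmm_basic.cpp (128x128x32 tile, 4 waves of 64 lanes, fp16). Every constant below is an assumption, including the 16-byte Problem::VectorLoadSize, so treat this as an illustrative sketch rather than part of the policy:

```
// Hypothetical worked example of the K0*K1*K2*K3 split (all values assumed):
constexpr int BlockSize = 256, NPerBlock = 128, KPerBlock = 32; // 4 waves of 64 lanes
constexpr int VectorBytes = 16, ElemBytes = 2;                  // assumed VectorLoadSize, fp16

constexpr int N1 = VectorBytes / ElemBytes;                     // 8 elements per vector load
constexpr int N0 = NPerBlock / N1;                              // 16
constexpr int total_pixels = NPerBlock * KPerBlock / BlockSize; // 16 elements per thread
constexpr int K3 = total_pixels / N1;                           // 2
constexpr int KPack = VectorBytes / ElemBytes;                  // 8, what GetVectorLoadSize() yields here
constexpr int K2 = KPack / K3;                                  // 4
constexpr int K1 = 64 / (K2 * N0);                              // 1, since 64 % (K2 * N0) == 0
constexpr int K0 = BlockSize / 64;                              // 4
static_assert(K0 * K1 * K2 * K3 == KPerBlock, "the four K factors must tile KPerBlock exactly");
```

Since K2 * N0 equals the 64-lane wave size in this configuration, the first branch of the if constexpr below is the one taken.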
constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // coalesce reading for each warps + else + { + constexpr index_t N0 = BlockSize / get_warp_size(); + constexpr index_t N1 = NPerBlock / (N2 * N0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetVectorLoadSize(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = BlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static 
constexpr auto MakeShuffledBRegBlockDescriptor() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetVectorLoadSize(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * N0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * N0); + constexpr index_t K0 = BlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile From e076a320ac70bf1241a8b5c867073c163e4127be Mon Sep 17 00:00:00 2001 From: feifei14119 Date: Mon, 10 Feb 2025 11:20:32 +0800 Subject: [PATCH 2/3] debug a --- example/ck_tile/18_flatmm/CMakeLists.txt | 2 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 8 +- include/ck_tile/ops/flatmm.hpp | 40 +- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 442 +++++++++++++++ ...atmm_asmem_bsmem_creg_v1_custom_policy.hpp | 38 ++ ...tmm_asmem_bsmem_creg_v1_default_policy.hpp | 59 ++ .../ops/flatmm/kernel/flatmm_kernel.hpp | 72 ++- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 137 ++++- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 517 ------------------ ...tmm_universal_pipeline_ag_bg_cr_policy.hpp | 35 +- 10 files changed, 742 insertions(+), 608 deletions(-) create mode 100644 include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp delete mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index dc52e049d3..1e0d6e5f17 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -5,5 +5,5 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-flo 
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef) #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -ggdb -g -O0 -v -save-temps) -list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1 -DDEBUG_CNT=64) target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 3ea7fb3765..9ee7889a17 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -183,9 +183,9 @@ int run_flatmm_example_with_layouts(int argc, ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); #if FEIFEI_DEBUG - ck_tile::HostTensor dbg_int({M * N * 64}); - ck_tile::HostTensor dbg_fp32({M * N * 64}); - ck_tile::HostTensor dbg_f168({M * N * 64}); + ck_tile::HostTensor dbg_int({M * N * DEBUG_CNT}); + ck_tile::HostTensor dbg_fp32({M * N * DEBUG_CNT}); + ck_tile::HostTensor dbg_f168({M * N * DEBUG_CNT}); ck_tile::DeviceMem dbg_int_buf(dbg_int.get_element_space_size_in_bytes()); ck_tile::DeviceMem dbg_fp32_buf(dbg_fp32.get_element_space_size_in_bytes()); @@ -362,7 +362,7 @@ int run_flatmm_example_with_layouts(int argc, int GridDimY = 1; int BlockDimX = 64; int BlockDimY = 4; - int DbgCnt = 64; + int DbgCnt = DEBUG_CNT; int BlockSize = BlockDimX * BlockDimY; // a_host { diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 8c8e37d90f..0fec22675d 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -3,40 +3,9 @@ #pragma once -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" -#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" -#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" -#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" -// #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -#include 
"ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" -// #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" -// #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" @@ -45,10 +14,15 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" -#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" +// block +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp" +// pipeline #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp" +// kernel +#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp new file mode 100644 index 0000000000..4efff75763 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -0,0 +1,442 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockFlatmmASmemBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + +#if 1 + // C += A * B + // template + template + CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window +#if FEIFEI_DEBUG + , + const BDataType* b_ptr, + int* dbg_int, + float* dbg_fp32, + void* dbg_f168 +#endif + ) const + { +#if FEIFEI_DEBUG + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[BLOCK ] BlockFlatmmASmemBSmemCRegV1():\n"); + } + + uint32_t tidx = threadIdx.x; + uint32_t tidy = threadIdx.y; + uint32_t bidx = blockIdx.x; + uint32_t bidy = blockIdx.y; + uint32_t bdmx = blockDim.x; + uint32_t bdmy = blockDim.y; + uint32_t gdmx = gridDim.x; + uint32_t gdmy = gridDim.y; + uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx; + + half_t* dbg_f16 = static_cast(dbg_f168); + for(int i = 0; i < DEBUG_CNT; i++) + { + dbg_int[gid * DEBUG_CNT + i] = -1; + dbg_fp32[gid * DEBUG_CNT + i] = -1.0f; + dbg_f16[gid * DEBUG_CNT + i] = ck_tile::type_convert(-1.0f); + } +#endif + /* + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + */ + + constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; + // constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; + + /* + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + */ + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + // constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = 
MPerBlock / MIterPerWarp; + // constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Warp loop in block: + constexpr index_t kIter = 0; + constexpr index_t mIter = 0; + const auto a_warp_tensor = load_tile(a_warp_windows(number{})(number{})); + +#if 1 + // feifei TODO: Implement gemm here +#else + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, MIterPerWarp> a_warp_windows{ + {a_warp_window_tmp}}; + + for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + 
statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); +#endif + } + +#else + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockWindow& a_block_window, + const BBlockWindow& b_block_window) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, MIterPerWarp> a_warp_windows{ + {a_warp_window_tmp}}; + + for(index_t mIter = 0; mIter < MIterPerWarp; mIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + 
move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + +#if 0 // FIXME: using array will cause register spill + array, NIterPerWarp> b_warp_windows{ + {b_warp_window_tmp}}; + + for(index_t nIter = 0; nIter < NIterPerWarp; nIter++) + { + for(index_t kIter = 0; kIter < KIterPerWarp; kIter++) + { + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + } + } +#else + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); +#endif + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read B warp tensor from B Block window + const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindow& b_block_window) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window); + return c_block_tensor; + } +#endif +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..30d4a7bc21 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced 
Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +template +struct BlockFlatmmASmemBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp new file mode 100644 index 0000000000..1dc71b209e --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_default_policy.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +struct BlockFlatmmASmemBSmemCRegV1DefaultPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { +#if 0 + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + static_assert(kBlockSize % get_warp_size() == 0, "wrong!"); + + constexpr index_t NumWarp = kBlockSize / get_warp_size(); + + if constexpr(NumWarp == 4 && kMPerBlock % 128 == 0 && + kNPerBlock % 128 == 0 % kKPerBlock % 16 == 0) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + } + else + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2); + } +#else + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); +#endif + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index dd52739545..de6040b758 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -85,7 +85,7 @@ struct FlatmmKernel CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { - return TilePartitioner::GridSize(M, N); + return TilePartitioner::GridSize(M, N); // feifei TODO: split K here // return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); } @@ -178,7 +178,7 @@ struct FlatmmKernel index_t a_k_split_offset; index_t b_k_split_offset; - index_t splitted_k; + index_t splitted_k; // problem K after 
splitted }; CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs) @@ -473,7 +473,7 @@ struct FlatmmKernel const BDataType* b_ptr, int* dbg_int, float* dbg_fp32, - short* dbg_f168 + void* dbg_f168 #endif ) { @@ -481,12 +481,55 @@ struct FlatmmKernel // Create Flatmm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset); - // origin layout - // const auto& gemm_tensor_views_tuple = - // MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + // Debug origin layout + // const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + // a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const auto& gemm_tile_windows = + MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + +#if FEIFEI_DEBUG + //////////////////////////////////////////////////////// + const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view + const auto& a_gemm_tensor_desc = a_gemm_tensor_views.desc_; // tensor_descriptor + const auto& a_gemm_buff_views = a_gemm_tensor_views.buf_; // buffer_view + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[KERNEL] a_gemm_tensor_view: size = %ld, len = [%d, %d], top = [%d, %d], upper = %d, lower = %d\n", + a_gemm_tensor_desc.get_element_space_size(), + a_gemm_tensor_desc.get_length(I0), a_gemm_tensor_desc.get_length(I1), + a_gemm_tensor_desc.get_top_dimension_hidden_ids()[0], a_gemm_tensor_desc.get_top_dimension_hidden_ids()[1], + a_gemm_tensor_desc.get_upper_dimension_hidden_idss()(I0)[0], + a_gemm_tensor_desc.get_lower_dimension_hidden_idss()(I0)[0] + ); + } + + const auto& a_pad_tensor_views = gemm_pad_views.at(I0); // tensor_view + const auto& a_pad_tensor_desc = a_pad_tensor_views.desc_; // tensor_descriptor + const auto& a_pad_buff_views = a_pad_tensor_views.buf_; // buffer_view + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[KERNEL] a_pad_tensor_view: size = %ld, len = [%d, %d], top = [%d, %d], upper = %d, lower = %d\n", + a_pad_tensor_desc.get_element_space_size(), + a_pad_tensor_desc.get_length(I0), a_pad_tensor_desc.get_length(I1), + a_pad_tensor_desc.get_top_dimension_hidden_ids()[0], a_pad_tensor_desc.get_top_dimension_hidden_ids()[1], + a_pad_tensor_desc.get_upper_dimension_hidden_idss()(I0)[0], + a_pad_tensor_desc.get_lower_dimension_hidden_idss()(I0)[0] + ); + } + + const auto& a_tile_win = gemm_tile_windows.at(I0); // tile_window_with_static_lengths + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[KERNEL] a_gemm_tile_window: dim_num = %d\n", + a_tile_win.get_num_of_dimension() + ); + } + //////////////////////////////////////////////////////// +#endif const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); @@ -555,11 +598,14 @@ struct FlatmmKernel int* dbg_int = static_cast(kargs.dbg_int_ptr); float* dbg_fp32 = static_cast(kargs.dbg_fp32_ptr); - short* dbg_f168 = static_cast(kargs.dbg_f168_ptr); + half_t* dbg_f16 = static_cast(kargs.dbg_f168_ptr); - dbg_int[gid] = 1; - dbg_fp32[gid] = 1.0f; - dbg_f168[gid] = ck_tile::type_convert(1.0f); + for(int i = 0; i < DEBUG_CNT; i++) + { + dbg_int[gid * DEBUG_CNT + i] = 0; + dbg_fp32[gid * DEBUG_CNT + i] = .0f; + dbg_f16[gid * 
DEBUG_CNT + i] = ck_tile::type_convert(0.f); + } #endif const auto [iM, iN] = TilePartitioner::GetOutputTileIndex(blockIdx.x, blockIdx.y); @@ -592,7 +638,7 @@ struct FlatmmKernel b_ptr, dbg_int, dbg_fp32, - dbg_f168 + kargs.dbg_f168_ptr #endif ); } @@ -611,7 +657,7 @@ struct FlatmmKernel b_ptr, dbg_int, dbg_fp32, - dbg_f168 + kargs.dbg_f168_ptr #endif ); } diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 4c37e827f3..bcd1ac7a14 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -4,14 +4,14 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp" namespace ck_tile { // A Tile Window: global memory // B Tile Window: global memory // C Distributed tensor: register -template +template // feifei TODO: add default policy struct FlatmmPipelineAGmemBGmemCRegV1 { using ADataType = remove_cvref_t; @@ -23,7 +23,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; + using BlockFlatmm = + remove_cvref_t())>; static constexpr index_t BlockSize = Problem::kBlockSize; @@ -41,21 +42,24 @@ struct FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { - return integer_divide_ceil( - sizeof(ADataType) * - Policy::template MakeALdsBlockDescriptor().get_element_space_size(), - 16) * + return integer_divide_ceil(sizeof(ADataType) * + PipelinePolicy::template MakeALdsBlockDescriptor() + .get_element_space_size(), + 16) * 16 + - sizeof(BDataType) * - Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + sizeof(BDataType) * PipelinePolicy::template MakeBLdsBlockDescriptor() + .get_element_space_size(); } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return Policy::template GetSmemSize(); + return PipelinePolicy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return PipelinePolicy::IsTransposeC(); + } template (-1.0f); + half_t* dbg_f16 = static_cast(dbg_f168); + for(int i = 0; i < DEBUG_CNT; i++) + { + dbg_int[gid * DEBUG_CNT + i] = 1; + dbg_fp32[gid * DEBUG_CNT + i] = 1.0f; + dbg_f16[gid * DEBUG_CNT + i] = ck_tile::type_convert(1.0f); + } #endif static_assert( std::is_same_v> && @@ -108,12 +117,93 @@ struct FlatmmPipelineAGmemBGmemCRegV1 #if 1 // feifei TODO: Implement gemm here + // Get block flatmm + auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1 + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template 
MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Prefetch ----------------------------------------------------------- + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + +#if FEIFEI_DEBUG // debug A global load + int a_dim = a_block_tile.get_num_of_dimension(); + int a_sz = a_block_tile.get_thread_buffer_size(); + + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] a_dim = %d, a_sz = %d\n", a_dim, a_sz); + } + for(auto i = 0; i < a_sz; i++) + { + dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i]; + } + return nullptr; +#endif + + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledARegBlockDescriptor()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + + // Loop --------------------------------------------------------------- + // Do flatmm + block_flatmm(a_lds_gemm_window +#if FEIFEI_DEBUG + , + b_ptr, + dbg_int, + dbg_fp32, + dbg_f168 +#endif + ); + + // Tail --------------------------------------------------------------- + return nullptr; #else // A tile in LDS ADataType* p_a_lds = static_cast(p_smem); - constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); @@ -125,7 +215,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 BDataType* p_b_lds = static_cast( static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); - constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + constexpr auto b_lds_block_desc = + PipelinePolicy::template MakeBLdsBlockDescriptor(); auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); @@ -134,7 +225,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + PipelinePolicy::template MakeADramTileDistribution()); // A LDS tile window for store auto a_copy_lds_window = make_tile_window( @@ -145,7 +236,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + PipelinePolicy::template MakeBDramTileDistribution()); // B LDS tile window for store auto b_copy_lds_window = make_tile_window( @@ -184,7 +275,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 if constexpr(std::is_same_v) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegBlockDescriptor()); + PipelinePolicy::template MakeShuffledARegBlockDescriptor()); shuffle_tile(a_shuffle_tmp, a_block_tile); const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); 
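// Note on the shuffle above (descriptive, not part of the original patch): when A's
// DRAM layout makes the per-thread vector loads run along M rather than K, the freshly
// loaded register tile does not match the K-packed layout expected by the A LDS
// descriptor. shuffle_tile() re-distributes the data into the layout given by
// MakeShuffledARegBlockDescriptor() so that the following store_tile() can write
// contiguous K-packs into LDS.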
store_tile(a_copy_lds_window, a_block_tile_tmp); @@ -198,7 +289,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 if constexpr(std::is_same_v) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDescriptor()); + PipelinePolicy::template MakeShuffledBRegBlockDescriptor()); shuffle_tile(b_shuffle_tmp, b_block_tile); const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); store_tile(b_copy_lds_window, b_block_tile_tmp); @@ -235,7 +326,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 if constexpr(std::is_same_v) { auto b_shuffle_tmp_loop = make_static_distributed_tensor( - Policy::template MakeShuffledBRegBlockDescriptor()); + PipelinePolicy::template MakeShuffledBRegBlockDescriptor()); shuffle_tile(b_shuffle_tmp_loop, b_block_tile); store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); @@ -271,7 +362,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 const BDataType* b_ptr, int* dbg_int, float* dbg_fp32, - short* dbg_f168 + void* dbg_f168 #endif ) const { diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp deleted file mode 100644 index 0cafaa37ab..0000000000 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ /dev/null @@ -1,517 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" - -namespace ck_tile { - -// Default policy for GemmPipelineAGmemBGmemCRegV1 -// Default policy class should not be templated, put template on member functions instead -struct FlatmmPipelineAGmemBGmemCRegV1DefaultPolicy -{ - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - - static constexpr bool TransposeC = true; - -#if 0 - // 2d - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{}); - - return a_lds_block_desc; - } - - // 2d - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc = - make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{}); - - return b_lds_block_desc; - } -#elif 1 - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - // TODO: this 8 is AK1! should be a policy parameter! 
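// How the "3d + padding" layout works (illustrative numbers, assuming kMPerBlock = 128
// and kKPerBlock = 32; the template arguments were reconstructed from the strides):
// the raw descriptor below is effectively (kKPerBlock/8) x kMPerBlock x 8 = (4, 128, 8)
// with strides ((kMPerBlock+1)*8, 8, 1) = (1032, 8, 1), i.e. each K-chunk of 8 elements
// is stored M-row by M-row with one extra 8-element row of padding per chunk
// (129*8 instead of 128*8). The padding skews LDS addresses between chunks to reduce
// bank conflicts. transform_tensor_descriptor() then merges (kKPerBlock/8, 8) back into
// a single K dimension, exposing a plain (M, K) view to the tile windows.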
- constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - // 3d + padding - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number<8>{}), - make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), - number<8>{}, - number<1>{}); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_pass_through_transform(kNPerBlock), - make_merge_transform(make_tuple(kKPerBlock / 8, 8))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() - { - constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * - MakeALdsBlockDescriptor().get_element_space_size(); - return smem_size_a; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() - { - constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * - MakeBLdsBlockDescriptor().get_element_space_size(); - return smem_size_b; - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - constexpr index_t smem_size_a = GetSmemSizeA(); - constexpr index_t smem_size_b = GetSmemSizeB(); - constexpr index_t smem_size = smem_size_a + smem_size_b; - - return smem_size; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() - { - using ADataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(ADataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() - { - using BDataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(BDataType); - } -#elif 1 - // fake XOR - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - using ADataType = remove_cvref_t; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(number{}, number<2>{}, number{}), - number{}); - - constexpr index_t kK1 = 16 / sizeof(ADataType); - - constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( - a_lds_block_desc_d1_d2_d3, - make_tuple( - make_xor_transform(make_tuple(number{}, number{}), kK1), - make_pass_through_transform(2)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{})); - - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), - make_pass_through_transform(kKPerBlock)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc_m_k; - } - - // fake XOR - template - 
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - using namespace ck_tile; - - using BDataType = remove_cvref_t; - - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed( - make_tuple(number{}, number<2>{}, number{}), - number{}); - - constexpr index_t kK1 = 16 / sizeof(BDataType); - - constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor( - b_lds_block_desc_d1_d2_d3, - make_tuple( - make_xor_transform(make_tuple(number{}, number{}), kK1), - make_pass_through_transform(2)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{})); - - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_d4_d5_d6, - make_tuple(make_merge_transform(make_tuple(number{}, number<2>{})), - make_pass_through_transform(kKPerBlock)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc_n_k; - } -#endif - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() - { - using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - if constexpr(std::is_same_v) - { - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t M0 = MPerBlock / M1; - constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; - static_assert(total_pixels % M1 == 0); - constexpr index_t K3 = total_pixels / M1; - constexpr index_t KPack = GetSmemPackA(); - static_assert(KPack % K3 == 0); - constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * M0)) - { - constexpr index_t K1 = get_warp_size() / (K2 * M0); - constexpr index_t K0 = BlockSize / get_warp_size(); - static_assert(KPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * M0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = BlockSize / get_warp_size() / K1; - static_assert(KPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - } - else - { - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - // coalesce reading for each blocks - if constexpr(get_warp_size() % (M2 * K0) == 0) - { - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = MPerBlock / (M2 * M1); - static_assert(M0 * M1 * M2 == MPerBlock, - "Incorrect M0, M2, M1 configuration! 
" - "M0, M1, M2 must cover whole MPerBlock!"); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - else - { - constexpr index_t M0 = BlockSize / get_warp_size(); - constexpr index_t M1 = MPerBlock / (M2 * M0); - static_assert(M0 * M1 * M2 == MPerBlock, - "Incorrect M0, M1, M2 configuration! " - "M0, M1, M2 must cover whole MPerBlock!"); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() - { - using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - if constexpr(std::is_same_v) - { - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t N0 = NPerBlock / N1; - constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; - static_assert(total_pixels % N1 == 0); - constexpr index_t K3 = total_pixels / N1; - constexpr index_t KPack = GetSmemPackB(); - static_assert(KPack % K3 == 0); - constexpr index_t K2 = KPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = BlockSize / get_warp_size(); - static_assert(KPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = BlockSize / get_warp_size() / K1; - static_assert(KPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - } - else - { - - constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - // coalesce reading for each blocks - if constexpr(get_warp_size() % (N2 * K0) == 0) - { - constexpr index_t N1 = BlockSize / get_warp_size(); - static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = NPerBlock / (N2 * N1); - static_assert(N0 * N1 * N2 == NPerBlock, - "Incorrect N0, N1, N2 configuration! " - "N0, N1, N2 must cover whole NPerBlock!"); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - // coalesce reading for each warps - else - { - constexpr index_t N0 = BlockSize / get_warp_size(); - constexpr index_t N1 = NPerBlock / (N2 * N0); - static_assert(N0 * N1 * N2 == NPerBlock, - "Incorrect N0, N1, N2 configuration! 
" - "N0, N1, N2 must cover whole NPerBlock!"); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); - } - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() - { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemPackB(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t warp_size = get_warp_size(); - if constexpr(warp_size % (K2 * N0) == 0) - { - constexpr index_t K1 = warp_size / (K2 * N0); - constexpr index_t K0 = kBlockSize / warp_size; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() - { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - - constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t M0 = kMPerBlock / M1; - constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % M1 == 0); - constexpr index_t K3 = total_pixels / M1; - constexpr index_t kKPack = GetSmemPackA(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - constexpr index_t warp_size = get_warp_size(); - if constexpr(warp_size % (K2 * M0) == 0) - { - constexpr index_t K1 = warp_size / (K2 * M0); - constexpr index_t K0 = kBlockSize / warp_size; - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } - else - { - constexpr index_t K1 = (K2 * M0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<1, 2>, - sequence<1, 3>>{}); - } - } - - CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } - - 
template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() - { - using AccDataType = float; - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; - - return BlockUniversalGemmAsBsCr{}; - } -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp index c60a1b0b1e..70f294223b 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp @@ -446,24 +446,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() { - using AccDataType = float; - using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmMfmaDispatcher; - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; - return BlockGemmASmemBSmemCRegV1{}; + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockFlatmmPolicy = + BlockFlatmmASmemBSmemCRegV1CustomPolicy; + return BlockFlatmmASmemBSmemCRegV1{}; } }; From e889d086c52b411abf0eef7aabbd401146f5f73c Mon Sep 17 00:00:00 2001 From: feifei14119 Date: Fri, 14 Feb 2025 10:51:06 +0800 Subject: [PATCH 3/3] save 51 --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 124 ++++++++ example/ck_tile/18_flatmm/flatmm_basic.hpp | 4 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 4 +- include/ck_tile/ops/flatmm.hpp | 8 +- .../block_flatmm_asmem_bsmem_creg_v1.hpp | 40 ++- .../ops/flatmm/kernel/flatmm_kernel.hpp | 215 +++++++++++-- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 297 +++++++++++++++--- ...tmm_universal_pipeline_ag_bg_cr_policy.hpp | 126 +++++++- 8 files changed, 728 insertions(+), 90 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index a46d61c8f4..b811275a42 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -12,6 +12,7 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" +#if 1 template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) { @@ -117,6 +118,129 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con return ave_time; } +#else +template +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. 
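// Quick sanity arithmetic for the tile configuration chosen below (illustrative only,
// assuming 64-lane wavefronts):
//   threads per block = M_Warp * N_Warp * K_Warp * 64 = 2 * 2 * 1 * 64 = 256
//   MIterPerWarp = M_Tile / (M_Warp * M_Warp_Tile) = 128 / (2 * 32) = 2
//   NIterPerWarp = N_Tile / (N_Warp * N_Warp_Tile) = 128 / (2 * 32) = 2
//   KIterPerWarp = K_Tile / K_Warp_Tile            = 64 / 16        = 4
// These are the per-wave repeat counts the block flatmm derives from BlockGemmShape,
// so every warp-tile dimension must divide its block-tile dimension, e.g. a
// hypothetical guard would be static_assert(M_Tile % (M_Warp * M_Warp_Tile) == 0).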
+ constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + // This part comes from the Codegen + /*constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8;*/ + + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::FlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + +#if FEIFEI_DEBUG + /*using BlockFlatmmStruct = ck_tile::remove_cvref_t())>; + auto block_flatmm = BlockFlatmmStruct(); // struct BlockFlatmmASmemBSmemCRegV1 + //auto ADramTileDistr = CodegenFlatmmPolicy::template MakeADramTileDistribution(); + + auto kernel = Kernel{}; + using SplitKBatchOffset = typename Kernel::SplitKBatchOffset; + SplitKBatchOffset splitk_batch_offset(args); + auto gemm_tensor_views_tuple = Kernel::template MakeGemmTensorViews( + args.a_ptr, + args.b_shuffle_ptr, + args.c_ptr, + kargs, splitk_batch_offset);*/ + + + printf("[FEIFEI] --- flatmm_calc() ---\n"); + printf("[FEIFEI] BlockPerCu = %d\n", static_cast(kBlockPerCu)); + printf("[FEIFEI] BlockTile M = %d\n", static_cast(M_Tile)); + printf("[FEIFEI] BlockTile N = %d\n", static_cast(N_Tile)); + printf("[FEIFEI] BlockTile K = %d\n", static_cast(K_Tile)); + printf("[FEIFEI] WavePerBlock M = %d\n", static_cast(M_Warp)); + printf("[FEIFEI] WavePerBlock N = %d\n", static_cast(N_Warp)); + printf("[FEIFEI] WavePerBlock K = %d\n", static_cast(K_Warp)); + printf("[FEIFEI] WaveTile M = %d\n", static_cast(M_Warp_Tile)); + printf("[FEIFEI] WaveTile N = %d\n", static_cast(N_Warp_Tile)); + printf("[FEIFEI] WaveTile K = %d\n", static_cast(K_Warp_Tile)); + printf("[FEIFEI] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z); + printf("[FEIFEI] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z); +#endif + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; +} +#endif #include "run_flatmm_example.inc" diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 19d0d362ef..6989824fb5 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -80,12 +80,12 @@ auto create_args(int argc, char* argv[]) .insert("n", "128", "n dimension") // 128, 4096 .insert("k", "64", "k dimension") // 64, 2048 .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 9ee7889a17..66944768a6 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -415,8 +415,8 @@ int run_flatmm_example_with_layouts(int argc, // b_shuffle { std::ofstream file("ff_b_shuffle_host.txt"); - int X = static_cast(K); - int Y = static_cast(N); + int X = 32 * 32; + int Y = static_cast(N) * static_cast(M) / X; file << " [b_shuffle_host]: Row = " << Y << ", Col = " << X << std::endl; for(int y = 0; y < Y; y++) diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 0fec22675d..cd792a9002 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -24,10 +24,10 @@ // kernel #include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp" -#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" -#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" -#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" -#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" +//#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" +//#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" +//#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp" +//#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 4efff75763..8995110e4b 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ 
b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -55,14 +55,13 @@ struct BlockFlatmmASmemBSmemCRegV1 return c_block_tensor; } -#if 1 +#if 0 // C += A * B // template - template - CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window + template + CK_TILE_DEVICE void operator()(const ABlockWindow& a_block_window, const BBlockWindow& b_block_window #if FEIFEI_DEBUG , - const BDataType* b_ptr, int* dbg_int, float* dbg_fp32, void* dbg_f168 @@ -101,14 +100,12 @@ struct BlockFlatmmASmemBSmemCRegV1 */ constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; - // constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}]; - /* static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && KPerBlock == BlockGemmShape::kK, "wrong!"); - */ constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -117,11 +114,11 @@ struct BlockFlatmmASmemBSmemCRegV1 constexpr index_t NWarp = config.template at<2>(); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - // constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); constexpr index_t KIterPerWarp = KPerBlock / WG::kK; constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; - // constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; const index_t iMWarp = get_warp_id() / NWarp; @@ -133,6 +130,7 @@ struct BlockFlatmmASmemBSmemCRegV1 make_tuple(number{}, number{}), a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + statically_indexed_array< statically_indexed_array, @@ -151,7 +149,29 @@ struct BlockFlatmmASmemBSmemCRegV1 // Warp loop in block: constexpr index_t kIter = 0; constexpr index_t mIter = 0; - const auto a_warp_tensor = load_tile(a_warp_windows(number{})(number{})); + const auto a_warp_tensor = load_tile(a_warp_window_tmp); + +#if FEIFEI_DEBUG + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[BLOCK ] WG::kM = %d, WG::kM = %d, WG::kK = %d, WG::kKPerThread = %d\n", WG::kM, WG::kN, WG::kK, WG::kKPerThread); + printf("[BLOCK ] MIterPerWarp = %d, NIterPerWarp = %d, KIterPerWarp = %d\n", MIterPerWarp, NIterPerWarp, KIterPerWarp); + } + + // debug A lds read + int warp_tile_size_per_thread = a_warp_tensor.get_thread_buffer_size(); + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[BLOCK ] warp_tile_size_per_thread = %d\n", warp_tile_size_per_thread); + } + for(auto i = 0; i < warp_tile_size_per_thread; i++) + { + dbg_f16[gid * DEBUG_CNT + i] = a_warp_tensor.get_thread_buffer()[i]; + } + + return ; +#endif + #if 1 // feifei TODO: Implement gemm here diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index de6040b758..e5a5892e35 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -141,7 +141,7 @@ struct FlatmmKernel struct SplitKBatchOffset { - __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs, + 
CK_TILE_DEVICE SplitKBatchOffset(const FlatmmKernelArgs& kargs, const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); @@ -175,7 +175,42 @@ struct FlatmmKernel splitted_k = kargs.K - KRead * (kargs.KBatch - 1); } } +#if FEIFEI_DEBUG + CK_TILE_HOST SplitKBatchOffset(const FlatmmHostArgs& hargs, + const std::size_t k_id = 0) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = hargs.k_batch * K1; + const index_t KRead = (hargs.K + K_t - 1) / K_t * K1; + + if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead; + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead * hargs.stride_A; + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * hargs.stride_B; + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + if(k_id < static_cast(hargs.k_batch - 1)) + { + splitted_k = KRead; + } + else + { + splitted_k = hargs.K - KRead * (hargs.k_batch - 1); + } + } +#endif index_t a_k_split_offset; index_t b_k_split_offset; index_t splitted_k; // problem K after splitted @@ -362,6 +397,9 @@ struct FlatmmKernel return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view); } +#if 1 + + template CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) { @@ -446,6 +484,118 @@ struct FlatmmKernel return make_tuple(a_block_window, b_block_window, c_block_window); } +#else + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, b_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& c_pad_view = views.at(I2); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& b_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + }(); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, 
i_n}); + + return make_tuple(a_block_window, b_block_window, c_block_window); + } +#endif /** * @brief Runs single GEMM problem cooperatively by whole workgroup. @@ -477,20 +627,29 @@ struct FlatmmKernel #endif ) { +#if FEIFEI_DEBUG + uint32_t tidx = threadIdx.x; + uint32_t tidy = threadIdx.y; + uint32_t bidx = blockIdx.x; + uint32_t bidy = blockIdx.y; + uint32_t bdmx = blockDim.x; + uint32_t bdmy = blockDim.y; + uint32_t gdmx = gridDim.x; + uint32_t gdmy = gridDim.y; + uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + (bdmx * bdmy) * bidx + bdmx * tidy + tidx; + + half_t* dbg_f16 = static_cast(kargs.dbg_f168_ptr); +#endif // Create Flatmm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_shuffle_ptr, c_ptr, kargs, splitk_batch_offset); - // Debug origin layout - // const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - // a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); const auto& gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - #if FEIFEI_DEBUG //////////////////////////////////////////////////////// const auto& a_gemm_tensor_views = gemm_tensor_views_tuple.at(I0); // tensor_view @@ -533,39 +692,51 @@ struct FlatmmKernel const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + const auto& b_flat_tensor_view = [&]() { + return make_naive_tensor_view( + b_shuffle_ptr, + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + }(); + const auto& b_flat_pad_view = [&]() { + return pad_tensor_view(b_flat_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + const auto& b_flat_block_window = make_tile_window( + b_flat_pad_view, + make_tuple(number{}, number{}), + {block_idx_n, 0}); + // Run GEMM cooperatively by whole workgroup. 
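// The b_shuffle_ptr path above views the pre-shuffled B as a row-major (N x K) tensor
// in global memory, pads it up to multiples of the flat per-warp tile
// (flatNPerWarp x flatKPerWarp), and gives each workgroup a window starting at
// {block_idx_n, 0}. In the current work-in-progress pipeline this window is consumed
// directly from global memory into registers, with no LDS staging for B, which is the
// point of the flat layout; the unshuffled b_block_window is only passed along for the
// FEIFEI_DEBUG comparisons.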
const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); const auto& c_block_tile = FlatmmPipeline{}.template operator()(a_block_window, - b_block_window, + b_flat_block_window, num_loop, smem_ptr #if FEIFEI_DEBUG , - b_ptr, + b_block_window, dbg_int, dbg_fp32, dbg_f168 #endif ); - // feifei TODO: Un-comment bellow once pipeline() is implemented -#if 0 // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); - - constexpr bool is_output_c_reg_transposed = - EpiloguePipeline::IsOutputTransposed() != FlatmmPipeline::IsTransposeC(); - if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) || - (FlatmmPipeline::VectorSizeC % 2 == 0 && - std::is_same_v && - is_output_c_reg_transposed)) + /*auto c_block_window = gemm_tile_windows.at(I2); + + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) { - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile); + printf("[PIPELN] C = %.3f\n", type_convert(c_block_tile.get_thread_buffer()[0])); } -#endif + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr);*/ } CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index bcd1ac7a14..7be734d1f4 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -14,6 +14,10 @@ namespace ck_tile { template // feifei TODO: add default policy struct FlatmmPipelineAGmemBGmemCRegV1 { + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -62,18 +66,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1 } template + typename BElementFunction +#if FEIFEI_DEBUG + , typename BDramBlockWindowTmp +#endif + > CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem #if FEIFEI_DEBUG , - const BDataType* b_ptr, + const BDramBlockWindowTmp& b_dram_block_window_tmp, int* dbg_int, float* dbg_fp32, void* dbg_f168 @@ -111,63 +119,107 @@ struct FlatmmPipelineAGmemBGmemCRegV1 "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] kMPerBlock = %d, winN = %d\n", kMPerBlock, + static_cast(ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}])); + printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock, + static_cast(BFlatBlockWindowTmp{}.get_window_lengths()[number<0>{}])); + printf("[PIPELN] kNPerBlock = %d, winN = %d\n", kNPerBlock, + static_cast(BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}])); + printf("[PIPELN] kKPerBlock = %d, winN = %d\n", kKPerBlock, + 
static_cast(ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}])); + } #if 1 // feifei TODO: Implement gemm here // Get block flatmm auto block_flatmm = BlockFlatmm(); // struct BlockFlatmmASmemBSmemCRegV1 - // A tile in LDS - ADataType* p_a_lds = static_cast(p_smem); - - constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeALdsBlockDescriptor(); - - auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - - constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * - 16; - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + auto a_copy_dram_window = // tile_window_with_static_distribution + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), PipelinePolicy::template MakeADramTileDistribution()); - // A LDS tile window for store - auto a_copy_lds_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + // B DRAM tile window for load + auto b_copy_dram_window = // tile_window_with_static_distribution + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeBDramTileDistribution()); - // A LDS tile for block GEMM - auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + // B flat DRAM window for load + auto b_flat_distribution = PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{} * 4), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); // Prefetch ----------------------------------------------------------- // global read 0 auto a_block_tile = load_tile(a_copy_dram_window); + auto b_block_tile = load_tile(b_copy_dram_window); + auto b_flat_tile = load_tile(b_flat_dram_window); -#if FEIFEI_DEBUG // debug A global load - int a_dim = a_block_tile.get_num_of_dimension(); - int a_sz = a_block_tile.get_thread_buffer_size(); - +#if FEIFEI_DEBUG + // debug A global load + int a_block_tile_size_per_thread = a_block_tile.get_thread_buffer_size(); if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) { - printf("[PIPELN] a_dim = %d, a_sz = %d\n", a_dim, a_sz); + printf("[PIPELN] a_block_tile_size_per_thread = %d\n", a_block_tile_size_per_thread); } - for(auto i = 0; i < a_sz; i++) + for(auto i = 0; i < a_block_tile_size_per_thread; i++) { dbg_f16[gid * DEBUG_CNT + i] = a_block_tile.get_thread_buffer()[i]; } + + // debug B global load + int b_block_tile_size_per_thread = b_block_tile.get_thread_buffer_size(); + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] b_block_tile_size_per_thread = %d\n", b_block_tile_size_per_thread); + } + for(auto i = 0; i < b_block_tile_size_per_thread; i++) + { + //dbg_f16[gid * DEBUG_CNT + i] = b_block_tile.get_thread_buffer()[i]; + } + + // debug flat B global load + int b_flat_tile_size_per_thread = b_flat_tile.get_thread_buffer_size(); + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] 
b_flat_tile_size_per_thread = %d\n", b_flat_tile_size_per_thread); + } + for(auto i = 0; i < b_flat_tile_size_per_thread; i++) + { + //dbg_f16[gid * DEBUG_CNT + i + b_block_tile_size_per_thread + 4] = b_flat_tile.get_thread_buffer()[i]; + } + return nullptr; #endif +#if 0 // move to 1 move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( // tile_window_with_static_lengths + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( // tile_window_with_static_lengths + a_lds_block, make_tuple(number{}, number{}), {0, 0}); // LDS write 0 if constexpr(std::is_same_v) @@ -183,12 +235,26 @@ struct FlatmmPipelineAGmemBGmemCRegV1 store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); } + // B tile in LDS + constexpr index_t a_lds_block_space_size_aligned = integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16; + BDataType* p_b_lds = static_cast(static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = PipelinePolicy::template MakeBLdsBlockDescriptor(); + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // B LDS tile window for store + auto b_copy_lds_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + // Loop --------------------------------------------------------------- // Do flatmm - block_flatmm(a_lds_gemm_window + block_sync_lds(); + block_flatmm(a_lds_gemm_window, b_lds_gemm_window #if FEIFEI_DEBUG , - b_ptr, dbg_int, dbg_fp32, dbg_f168 @@ -198,6 +264,157 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // Tail --------------------------------------------------------------- return nullptr; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // A tile in LDS + /*ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = + PipelinePolicy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile window for store + auto b_copy_lds_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // 
Block GEMM + auto block_gemm = BlockFlatmm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + // prefetch + // global read 0 + //auto a_block_tile = load_tile(a_copy_dram_window); + //auto b_block_tile = load_tile(b_copy_dram_window); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledARegBlockDescriptor()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp, b_block_tile); + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + else + { + store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + } + } + + index_t iCounter = num_loop - 1; + while(iCounter > 0) + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + + block_sync_lds(); + + // GEMM i + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // LDS write i + 1 + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp_loop, b_block_tile); + store_tile(b_copy_lds_window, + tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + } + else + { + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + + iCounter--; + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + + int c_block_tile_size_per_thread = c_block_tile.get_thread_buffer_size(); + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] c_block_tile_size_per_thread = %d\n", c_block_tile_size_per_thread); + } + for(auto i = 0; i < c_block_tile_size_per_thread; i++) + { + //dbg_fp32[gid * DEBUG_CNT + i] = c_block_tile.get_thread_buffer()[i]; + dbg_fp32[gid * DEBUG_CNT + i] = 3.12f; + c_block_tile.get_thread_buffer()[i] = 1.23f; + } + return c_block_tile;*/ +//////////////////////////////////////////////////////////////////////////////////////////////////// + + #else // A tile in LDS ADataType* p_a_lds = static_cast(p_smem); @@ -352,14 +569,18 @@ struct FlatmmPipelineAGmemBGmemCRegV1 #endif } - template + template CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& 
b_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, void* p_smem #if FEIFEI_DEBUG , - const BDataType* b_ptr, + const BDramBlockWindowTmp& b_dram_block_window_tmp, int* dbg_int, float* dbg_fp32, void* dbg_f168 @@ -369,13 +590,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 return operator()( a_dram_block_window_tmp, [](const ADataType & a) { return a; }, - b_dram_block_window_tmp, + b_flat_dram_block_window_tmp, [](const BDataType & b) { return b; }, num_loop, p_smem #if FEIFEI_DEBUG , - b_ptr, + b_dram_block_window_tmp, dbg_int, dbg_fp32, dbg_f168 diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp index 70f294223b..af6d570b4f 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_universal_pipeline_ag_bg_cr_policy.hpp @@ -227,15 +227,24 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } else { - constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); // dwordx4 load A elem cnt + constexpr index_t K0 = KPerBlock / K1; // threads cnt in K dim + constexpr index_t M2 = get_warp_size() / K0; // threads cnt in M dim (per wave) if constexpr(get_warp_size() % (M2 * K0) == 0) { - constexpr index_t M1 = BlockSize / get_warp_size(); + constexpr index_t M1 = BlockSize / get_warp_size(); // wave cnt in M dim (per block) static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = MPerBlock / (M2 * M1); + constexpr index_t M0 = MPerBlock / (M2 * M1); // load repeat times in M dim +#if FEIFEI_DEBUG + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] MakeADramTileDistribution():\n"); + printf("[PIPELN] MPerBlock = %d, KPerBlock = %d, AperBlock = %d\n", MPerBlock, KPerBlock, MPerBlock*KPerBlock); + printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize); + printf("[PIPELN] K1 = %d, K0 = %d, M2 = %d, M1 = %d, M0 = %d\n", K1, K0, M2, M1, M0); + } +#endif return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, @@ -310,18 +319,25 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } else { - - constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt + constexpr index_t K0 = KPerBlock / K1; // threads cnt in K dim + constexpr index_t N2 = get_warp_size() / K0; // threads cnt in N dim (per wave) // coalesce reading for each blocks if constexpr(get_warp_size() % (N2 * K0) == 0) { - constexpr index_t N1 = BlockSize / get_warp_size(); + constexpr index_t N1 = BlockSize / get_warp_size(); // wave cnt in N dim (per block) static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = NPerBlock / (N2 * N1); - + constexpr index_t N0 = NPerBlock / (N2 * N1); // load repeat times in N dim +#if FEIFEI_DEBUG + 
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] MakeBDramTileDistribution():\n"); + printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n", NPerBlock, KPerBlock, NPerBlock*KPerBlock); + printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", BlockSize, get_warp_size(), Problem::VectorLoadSize); + printf("[PIPELN] K1 = %d, K0 = %d, N2 = %d, N1 = %d, N0 = %d\n", K1, K0, N2, N1, N0); + } +#endif return make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, @@ -347,6 +363,92 @@ struct UniversalFlatmmPipelineAgBgCrPolicy } } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetVectorLoadSize(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t KLoad = Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt + constexpr index_t KThdInBlk = 64; + constexpr index_t KBlkInTile = 1; + constexpr index_t KRepeat = 1; + constexpr index_t NLoad = 1; // dwordx4 load B elem cnt + constexpr index_t NThdInBlk = 1; + constexpr index_t NBlkInTile = 4; + constexpr index_t NRepeat = 1; +#if FEIFEI_DEBUG + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPELN] MakeBFlatDramTileDistribution():\n"); + printf("[PIPELN] NPerBlock = %d, KPerBlock = %d, BperBlock = %d\n", + NPerBlock, + KPerBlock, + NPerBlock * KPerBlock); + printf("[PIPELN] BlockSize = %d, warp_size = %d, VectorLoadSize = %d\n", + BlockSize, + get_warp_size(), + Problem::VectorLoadSize); + printf("[PIPELN] NRepeat = %d, NBlkInTile = %d, NThdInBlk = %d, NLoad = %d\n", NRepeat, NBlkInTile, NThdInBlk, NLoad); + printf("[PIPELN] KRepeat = %d, KBlkInTile = %d, KThdInBlk = %d, KLoad = %d\n", KRepeat, KBlkInTile, KThdInBlk, KLoad); + } +#endif + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, // first dim + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<2, 2>, + sequence<0, 3>>{}); + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() {
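For readers following the pipeline code earlier in this patch: the (currently disabled) main loop implements the usual single-buffered prefetch schedule. Tile 0 is staged into LDS, then each iteration issues the global read for K-slice i+1 while the block GEMM consumes the slice already staged for i, and a tail GEMM handles the last staged slice. The snippet below is a host-side analogue of that control flow only; the array names, sizes, and the `load_stage`/`gemm_stage` helpers are illustrative stand-ins, not the ck_tile API.

```
// Host-side analogue of the prefetch schedule sketched in the pipeline above.
// Everything here (names, sizes, helpers) is illustrative, not the ck_tile API:
// the staging arrays stand in for LDS, load_stage() for the global->LDS copy,
// and gemm_stage() for the block GEMM on the staged K-slice.
#include <array>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int M = 4, N = 4, K = 16, KPerBlock = 4;
    constexpr int num_loop = K / KPerBlock;

    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);
    std::array<float, M * KPerBlock> a_stage{}; // "LDS" tile of A
    std::array<float, KPerBlock * N> b_stage{}; // "LDS" tile of B

    auto load_stage = [&](int loop) { // global read + LDS write for K-slice `loop`
        for(int m = 0; m < M; ++m)
            for(int k = 0; k < KPerBlock; ++k)
                a_stage[m * KPerBlock + k] = A[m * K + loop * KPerBlock + k];
        for(int k = 0; k < KPerBlock; ++k)
            for(int n = 0; n < N; ++n)
                b_stage[k * N + n] = B[(loop * KPerBlock + k) * N + n];
    };
    auto gemm_stage = [&]() { // block GEMM on whatever is currently staged
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = 0; k < KPerBlock; ++k)
                    C[m * N + n] += a_stage[m * KPerBlock + k] * b_stage[k * N + n];
    };

    load_stage(0);                        // prefetch: "global read 0" / "LDS write 0"
    for(int i = 0; i < num_loop - 1; ++i) // main loop over K-slices
    {
        gemm_stage();                     // "GEMM i" on the staged slice
        load_stage(i + 1);                // stage slice i + 1 for the next iteration
    }
    gemm_stage();                         // tail: "GEMM num_loop - 1"

    std::printf("C[0] = %.1f (expected %.1f)\n", C[0], 2.0f * K);
}
```

On the GPU the ordering differs slightly: the global read for slice i+1 is issued before the GEMM on slice i so that the load latency overlaps with compute, with `block_sync_lds()` separating the LDS reads and writes; in this serial analogue the order does not affect the result.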
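The DRAM tile distributions above all follow one decomposition: the vector load width fixes how many elements a thread reads along K (K1), the remaining lanes of a wave spread across K and M/N (K0, M2/N2), the waves of the workgroup add another M/N factor (M1/N1), and whatever is left becomes a per-thread repeat count (M0/N0). Below is a minimal standalone sketch of that arithmetic, assuming fp16 data, a 16-byte (dwordx4) vector load, a 128x128x32 block tile, a 256-thread workgroup, and a 64-lane wavefront; these numbers are illustrative assumptions, and the real values are derived from `Problem` at compile time inside the policy.

```
// Illustrative only: the A-tile distribution arithmetic on assumed numbers.
#include <cstdio>

int main()
{
    constexpr int VectorLoadSize = 16;  // bytes per global load (dwordx4), assumed
    constexpr int ElemSize       = 2;   // sizeof(fp16)
    constexpr int MPerBlock      = 128;
    constexpr int KPerBlock      = 32;
    constexpr int BlockSize      = 256; // threads per workgroup, assumed
    constexpr int WarpSize       = 64;  // wavefront size

    constexpr int K1 = VectorLoadSize / ElemSize; // elements per vector load   -> 8
    constexpr int K0 = KPerBlock / K1;            // threads along K            -> 4
    constexpr int M2 = WarpSize / K0;             // threads along M per wave   -> 16
    constexpr int M1 = BlockSize / WarpSize;      // waves per block            -> 4
    constexpr int M0 = MPerBlock / (M2 * M1);     // per-thread repeats along M -> 2

    static_assert(M0 * M1 * M2 == MPerBlock, "M decomposition must cover the block tile");
    static_assert(K0 * K1 == KPerBlock, "K decomposition must cover the block tile");

    std::printf("K1=%d K0=%d M2=%d M1=%d M0=%d\n", K1, K0, M2, M1, M0);
}
```

The B-tile distribution uses the same arithmetic with N in place of M, which is why the debug printfs in `MakeADramTileDistribution()` and `MakeBDramTileDistribution()` report the same five factors.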
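The ColumnMajor branch of `MakeBFlatDramTileDistribution()` splits K further into K0·K1·K2·K3 around a per-thread `total_pixels` count and a `KPack` factor. Under the same assumed configuration, and additionally assuming the `GetVectorLoadSize()` helper yields 8 elements here (an assumption; it may differ), the numbers work out as below and satisfy the `static_assert`s in the code.

```
// Illustrative only: the ColumnMajor branch of MakeBFlatDramTileDistribution()
// evaluated on assumed numbers (fp16, 16-byte vector load, 128x128x32 tile,
// 256 threads, 64-lane wavefront, KPack assumed to be 8 elements).
#include <cstdio>

int main()
{
    constexpr int VectorLoadSize = 16;  // bytes per global load, assumed
    constexpr int ElemSize       = 2;   // sizeof(fp16)
    constexpr int NPerBlock      = 128;
    constexpr int KPerBlock      = 32;
    constexpr int BlockSize      = 256; // assumed
    constexpr int WarpSize       = 64;
    constexpr int KPack          = 8;   // assumed GetVectorLoadSize() result

    constexpr int N1 = VectorLoadSize / ElemSize;                    // 8 elements per vector load
    constexpr int N0 = NPerBlock / N1;                               // 16 vector loads to cover N
    constexpr int total_pixels = NPerBlock * KPerBlock / BlockSize;  // 16 elements per thread
    static_assert(total_pixels % N1 == 0, "per-thread work must split evenly");
    constexpr int K3 = total_pixels / N1;                            // 2
    static_assert(KPack % K3 == 0, "KPack must be divisible by K3");
    constexpr int K2 = KPack / K3;                                   // 4
    static_assert(WarpSize % (K2 * N0) == 0, "takes the first sub-branch");
    constexpr int K1 = WarpSize / (K2 * N0);                         // 1
    constexpr int K0 = BlockSize / WarpSize;                         // 4
    static_assert(K0 * K1 * K2 * K3 == KPerBlock, "K decomposition must cover the block tile");

    std::printf("N1=%d N0=%d K3=%d K2=%d K1=%d K0=%d\n", N1, N0, K3, K2, K1, K0);
}
```

The non-ColumnMajor branch instead hard-codes its factors (KThdInBlk = 64, NBlkInTile = 4, and so on) for the flat B layout, which is what the FEIFEI_DEBUG printfs in that branch report.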