[DRAFT][Flatmm] implement framework only #1872

Draft · wants to merge 5 commits into develop
Changes from all commits
9 changes: 9 additions & 0 deletions example/ck_tile/18_flatmm/CMakeLists.txt
@@ -0,0 +1,9 @@
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)

set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -ggdb -g -O0 -v -save-temps)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DFEIFEI_DEBUG=1 -DDEBUG_CNT=64)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
35 changes: 35 additions & 0 deletions example/ck_tile/18_flatmm/README.md
@@ -0,0 +1,35 @@
# FLATMM Matrix Multiplication

This folder contains an example of FLATMM using the ck_tile tile-programming implementation. Currently it supports only the basic features of CK Tile FLATMM, but it creates placeholders for future support of different FLATMM pipelines and modules. In the near future, we will gradually migrate all FLATMM features from the old CK to CK Tile.

## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# build the example, which uses the basic flatmm pipeline
make tile_example_flatmm_basic -j
```
This will result in an executable `build/bin/tile_example_flatmm_basic`.

## example
```
args:
-m          m dimension (default:128)
-n          n dimension (default:128)
-k          k dimension (default:64)
-a_layout   Tensor A data layout (default: R)
-b_layout   Tensor B data layout (default: C)
-c_layout   Tensor C data layout (default: R)
-stride_a   Tensor A stride (default:0)
-stride_b   Tensor B stride (default:0)
-stride_c   Tensor C stride (default:0)
-v          0: no validation, 1: validation on CPU, 2: validation on GPU (default:2)
-prec       data type, fp16/bf16/fp8/bf8 (default:fp16)
-warmup     number of iterations before benchmarking the kernel (default:50)
-repeat     number of iterations to benchmark the kernel (default:100)
-timer      gpu: gpu timer, cpu: cpu timer (default:gpu)
-split_k    splitK value (default:1)
```
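
Once built, any of the arguments above can be passed on the command line. A minimal invocation might look like the following (illustrative only: it assumes the `-name=value` syntax of ck_tile's `ArgParser` and the default `build/bin` output location):
```
# 128x128x64 fp16 flatmm with GPU validation
./bin/tile_example_flatmm_basic -m=128 -n=128 -k=64 -prec=fp16 -v=2
```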
137 changes: 137 additions & 0 deletions example/ck_tile/18_flatmm/flatmm_basic.cpp
@@ -0,0 +1,137 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <hip/hip_runtime.h>

#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>

#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;

constexpr bool kTilePermute = false;
// The rank and permutation will also be generated by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;

constexpr int kBlockPerCu = 1;

// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;

constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;

constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
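
// How these tiles compose (assuming wave64, as on gfx90a/gfx942): the 1x4x1
// warp grid places 4 waves (256 threads) in a workgroup that owns a 128x128
// tile of C and walks K in steps of 32. Each wave issues 32x32x8 MFMA tiles,
// so one K_Tile step is 32/8 = 4 issues, with 128/(1*32) = 4 M-repeats and
// 128/(4*32) = 1 N-repeat per wave.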

// Whether to do the CShuffle (transpose before writing to global memory) depends on the
// output layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;

using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

using TilePartitioner = ck_tile::GemmTile2DPartitioner<CodegenGemmShape>;

using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC>>;

using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;

// ToDo: add the codegen part to test different pipeline policies in GEMM.
// For now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

auto kargs = Kernel::MakeKernelArgs(args);

const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
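// For example, M = N = 1024 with k_batch = 1 gives (1024/128) x (1024/128) =
// 64 workgroup tiles; how they map onto the launch grid is decided by
// GemmTile2DPartitioner.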

#if FEIFEI_DEBUG
using BlockFlatmmStruct = ck_tile::remove_cvref_t<decltype(CodegenFlatmmPolicy::template GetBlockFlatmm<CodegenPipelineProblem>())>;
auto block_flatmm = BlockFlatmmStruct(); // struct BlockFlatmmASmemBSmemCRegV1
//auto ADramTileDistr = CodegenFlatmmPolicy::template MakeADramTileDistribution<CodegenPipelineProblem>();

auto kernel = Kernel{};
using SplitKBatchOffset = typename Kernel::SplitKBatchOffset;
SplitKBatchOffset splitk_batch_offset(args);
auto gemm_tensor_views_tuple = Kernel::template MakeGemmTensorViews<ck_tile::memory_operation_enum::set>(
args.a_ptr,
args.b_shuffle_ptr,
args.c_ptr,
kargs, splitk_batch_offset);


printf("[FEIFEI] --- flatmm_calc() ---\n");
printf("[FEIFEI] BlockPerCu = %d\n", static_cast<int>(kBlockPerCu));
printf("[FEIFEI] BlockTile M = %d\n", static_cast<int>(M_Tile));
printf("[FEIFEI] BlockTile N = %d\n", static_cast<int>(N_Tile));
printf("[FEIFEI] BlockTile K = %d\n", static_cast<int>(K_Tile));
printf("[FEIFEI] WavePerBlock M = %d\n", static_cast<int>(M_Warp));
printf("[FEIFEI] WavePerBlock N = %d\n", static_cast<int>(N_Warp));
printf("[FEIFEI] WavePerBlock K = %d\n", static_cast<int>(K_Warp));
printf("[FEIFEI] WaveTile M = %d\n", static_cast<int>(M_Warp_Tile));
printf("[FEIFEI] WaveTile N = %d\n", static_cast<int>(N_Warp_Tile));
printf("[FEIFEI] WaveTile K = %d\n", static_cast<int>(K_Warp_Tile));
printf("[FEIFEI] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z);
printf("[FEIFEI] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z);
#endif

if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}

if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}

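// make_kernel's two template arguments (threads per block and the
// min-blocks-per-CU occupancy hint) feed the kernel's launch bounds; the
// trailing arguments are the grid, block, and dynamic-LDS-size launch
// parameters.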
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

return ave_time;
}

#include "run_flatmm_example.inc"

int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
100 changes: 100 additions & 0 deletions example/ck_tile/18_flatmm/flatmm_basic.hpp
@@ -0,0 +1,100 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"

#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2

#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif

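// CK_TILE_PIPELINE_DEFAULT picks one of the two pipeline/scheduler pairings
// below: the memory-oriented pipeline (AgBgCrMem) with the interwave
// scheduler, or the compute-oriented pipeline (AgBgCrCompV3) with the
// intrawave scheduler.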
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif

template <typename DataType>
struct GemmBasicTypeConfig;

template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};

template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};

template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};

using Types = GemmBasicTypeConfig<ck_tile::half_t>;

// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser
.insert("m", "128", "m dimension") // 128, 3840
.insert("n", "128", "n dimension") // 128, 4096
.insert("k", "64", "k dimension") // 64, 2048
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
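
// Illustrative only (not part of this change): a caller could read the parsed
// options back, assuming ck_tile::ArgParser exposes the get_int/get_str
// accessors used elsewhere in the ck_tile examples, e.g.:
//   auto [ok, parser] = create_args(argc, argv);
//   if(!ok) { return -1; }
//   const ck_tile::index_t m = parser.get_int("m");
//   const std::string prec  = parser.get_str("prec");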

// host API
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);