Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PROTOTYPE] add new pk_i4 cvt for fp16 and bf16 #1847

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ ENDIF()
ENDFOREACH()

add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
add_subdirectory(library)
#add_subdirectory(library)

if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
rocm_package_setup_component(tests
Expand All @@ -601,20 +601,20 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
PACKAGE_NAME examples
)
add_subdirectory(example)
if(BUILD_TESTING)
add_subdirectory(test)
endif()
#if(BUILD_TESTING)
#add_subdirectory(test)
#endif()
endif()

rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
#add_subdirectory(profiler)

if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
add_subdirectory(codegen)
endif()
#if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
#add_subdirectory(codegen)
#endif()

#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
Expand Down
79 changes: 75 additions & 4 deletions example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,35 @@
using ADataType = ck::bhalf_t;
using BDataType = ck::pk_i4_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CShuffleDataType = float;
using CDataType = ck::bhalf_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

// Element-wise output op: multiplies the incoming value by scale_ and stores
// the product converted to bf16.
struct Scale
{
    // Generic form is declared only; the (bhalf_t, float) pairing below is the
    // sole definition provided here.
    template <typename E, typename C>
    __host__ __device__ constexpr void operator()(E& e, const C& c) const;

    // e: destination bf16 element; c: fp32 value to scale.
    template <>
    __host__ __device__ constexpr void
    operator()<ck::bhalf_t, float>(ck::bhalf_t& e, const float& c) const
    {
        e = ck::type_convert<ck::bhalf_t>(c * scale_);
    }

    // Multiplier applied to every element.
    float scale_ = 16.0;
};

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
using CElementOp = Scale;
//using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
Expand Down Expand Up @@ -121,7 +140,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0, 1.0});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
}

Expand Down Expand Up @@ -149,7 +168,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
//b_k_n_permute(j * N * K1 + i * K1 + jj) = ck::bit_cast<uint8_t>(b_k_n(i * K + (j * K1 + jj))) + 0x88;
b_k_n_permute(j * N * K1 + i * K1 + jj) = ck::bit_cast<uint8_t>(b_k_n(i * K + (j * K1 + jj)));
}
}
}
Expand All @@ -165,6 +185,57 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
}

#if 1
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];

for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}

// permute 01234567->04261537
{
int hi = input[4];
int lo = input[0];
int i4x2 = (hi << 4) | lo;

b_k_n_permute(j + 0, i) = i4x2;
}

{
int hi = input[6];
int lo = input[2];
int i4x2 = (hi << 4) | lo;

b_k_n_permute(j + 2, i) = i4x2;
}

{
int hi = input[5];
int lo = input[1];
int i4x2 = (hi << 4) | lo;

b_k_n_permute(j + 4, i) = i4x2;
}

{
int hi = input[7];
int lo = input[3];
int i4x2 = (hi << 4) | lo;

b_k_n_permute(j + 6, i) = i4x2;
}
}
}
#endif

a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace;
Expand Down
6 changes: 3 additions & 3 deletions include/ck/library/utility/host_tensor_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct GeneratorTensor_1<ck::pk_i4_t>
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int t = value + 8;
int t = value;
ck::pk_i4_t r = ((t << 4) + t) & 0xff;
return r;
}
Expand Down Expand Up @@ -144,8 +144,8 @@ struct GeneratorTensor_2<ck::pk_i4_t>
template <typename... Is>
ck::pk_i4_t operator()(Is...)
{
int hi = std::rand() % (max_value - min_value) + min_value + 8;
int lo = std::rand() % (max_value - min_value) + min_value + 8;
int hi = std::rand() % (max_value - min_value) + min_value;
int lo = std::rand() % (max_value - min_value) + min_value;
ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
return r;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,83 @@

namespace ck {

// Converts eight packed 4-bit integers to eight bf16 values with
// v_cvt_off_f32_i4 and SDWA sub-dword selects.
//
// q holds eight nibbles; numbering them by bit position (nibble k occupies
// bits [4k+3:4k]), the returned lanes 0..7 are produced from nibbles
// 0, 4, 2, 6, 1, 5, 3, 7 in that order.
//
// Per the AMD ISA documentation, v_cvt_off_f32_i4 maps a signed 4-bit value v
// to v / 16.0f, so every lane is 1/16 of the integer value and the caller is
// expected to compensate for that factor. Each bf16 lane is the upper 16 bits
// of the corresponding f32 result.
//
// NOTE(review): relies on SDWA modifiers (src0_sel / dst_sel / dst_unused);
// confirm that every targeted GPU architecture supports them.
__device__ inline bhalf8_t pki4_to_bhalf8(uint32_t q)
{
    uint32_t bhalfx2_0, bhalfx2_1, bhalfx2_2, bhalfx2_3;
    float tmp_0;
    uint32_t tmp_2;
    vector_type<bhalf_t, 8> res;

    // Operand constraints: the temporaries and destinations are written by the
    // asm before they are ever read, so they are pure outputs. They are marked
    // early-clobber ("=&v") because the source register is still read after
    // the first output has been written and must not share a register with
    // any of them. The source itself is never modified, hence a plain input.
    asm volatile(
        // Low nibble of each byte: bytes 0/2 -> dst_0, bytes 1/3 -> dst_1.
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_0\n"
        "v_cvt_off_f32_i4 %[v_dst_0], %[v_src], src0_sel:BYTE_2\n"
        // Drop the high half of tmp_0 into the low half of dst, keeping the
        // high half of dst (the second converted value) intact.
        "v_mov_b32 %[v_dst_0], %[v_tmp_0], dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE "
        "src0_sel:WORD_1\n"
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
        "v_cvt_off_f32_i4 %[v_dst_1], %[v_src], src0_sel:BYTE_3\n"
        "v_mov_b32 %[v_dst_1], %[v_tmp_0], dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE "
        "src0_sel:WORD_1\n"
        // Shift by 4 so the high nibble of each byte becomes its low nibble,
        // then repeat the same sequence for dst_2 / dst_3.
        "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_0\n"
        "v_cvt_off_f32_i4 %[v_dst_2], %[v_tmp_2], src0_sel:BYTE_2\n"
        "v_mov_b32 %[v_dst_2], %[v_tmp_0], dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE "
        "src0_sel:WORD_1\n"
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
        "v_cvt_off_f32_i4 %[v_dst_3], %[v_tmp_2], src0_sel:BYTE_3\n"
        "v_mov_b32 %[v_dst_3], %[v_tmp_0], dst_sel:WORD_0 dst_unused:UNUSED_PRESERVE "
        "src0_sel:WORD_1"
        : [v_tmp_0] "=&v"(tmp_0),
          [v_tmp_2] "=&v"(tmp_2),
          [v_dst_0] "=&v"(bhalfx2_0),
          [v_dst_1] "=&v"(bhalfx2_1),
          [v_dst_2] "=&v"(bhalfx2_2),
          [v_dst_3] "=&v"(bhalfx2_3)
        : [v_src] "v"(q));

    // Each 32-bit register carries two bf16 lanes (low half first).
    res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(bhalfx2_0);
    res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(bhalfx2_1);
    res.template AsType<bhalf2_t>()(Number<2>{}) = bit_cast<bhalf2_t>(bhalfx2_2);
    res.template AsType<bhalf2_t>()(Number<3>{}) = bit_cast<bhalf2_t>(bhalfx2_3);

    return res.template AsType<bhalf8_t>()[Number<0>{}];
}

// Converts eight packed 4-bit integers to eight fp16 values with
// v_cvt_off_f32_i4 followed by v_cvt_pkrtz_f16_f32.
//
// q holds eight nibbles; numbering them by bit position (nibble k occupies
// bits [4k+3:4k]), the returned lanes 0..7 are produced from nibbles
// 0, 4, 2, 6, 1, 5, 3, 7 in that order.
//
// Per the AMD ISA documentation, v_cvt_off_f32_i4 maps a signed 4-bit value v
// to v / 16.0f, so every lane is 1/16 of the integer value and the caller is
// expected to compensate for that factor.
//
// NOTE(review): relies on the SDWA src0_sel modifier; confirm that every
// targeted GPU architecture supports it.
__device__ inline half8_t pki4_to_half8(uint32_t q)
{
    uint32_t halfx2_0, halfx2_1, halfx2_2, halfx2_3;
    float tmp_0, tmp_1;
    uint32_t tmp_2;
    vector_type<half_t, 8> res;

    // Operand constraints: the temporaries and destinations are written by the
    // asm before they are ever read, so they are pure outputs. They are marked
    // early-clobber ("=&v") because the source register is still read after
    // the first output has been written and must not share a register with
    // any of them. The source itself is never modified, hence a plain input.
    asm volatile(
        // Low nibble of each byte: bytes 0/2 -> dst_0, bytes 1/3 -> dst_1.
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_0\n"
        "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
        // Pack the two f32 values into one register of two fp16 lanes.
        "v_cvt_pkrtz_f16_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
        "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
        "v_cvt_pkrtz_f16_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
        // Shift by 4 so the high nibble of each byte becomes its low nibble,
        // then repeat the same sequence for dst_2 / dst_3.
        "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_0\n"
        "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
        "v_cvt_pkrtz_f16_f32 %[v_dst_2], %[v_tmp_0], %[v_tmp_1]\n"
        "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
        "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
        "v_cvt_pkrtz_f16_f32 %[v_dst_3], %[v_tmp_0], %[v_tmp_1]"
        : [v_tmp_0] "=&v"(tmp_0),
          [v_tmp_1] "=&v"(tmp_1),
          [v_tmp_2] "=&v"(tmp_2),
          [v_dst_0] "=&v"(halfx2_0),
          [v_dst_1] "=&v"(halfx2_1),
          [v_dst_2] "=&v"(halfx2_2),
          [v_dst_3] "=&v"(halfx2_3)
        : [v_src] "v"(q));

    // Each 32-bit register carries two fp16 lanes (low half first).
    res.template AsType<half2_t>()(Number<0>{}) = bit_cast<half2_t>(halfx2_0);
    res.template AsType<half2_t>()(Number<1>{}) = bit_cast<half2_t>(halfx2_1);
    res.template AsType<half2_t>()(Number<2>{}) = bit_cast<half2_t>(halfx2_2);
    res.template AsType<half2_t>()(Number<3>{}) = bit_cast<half2_t>(halfx2_3);

    return res.template AsType<half8_t>()[Number<0>{}];
}

// Fast int4x4 to half8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
Expand Down Expand Up @@ -166,6 +243,8 @@ struct PassThroughPack8
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);

y = result.template AsType<half8_t>()[Number<0>{}];
#elif 1
y = pki4_to_half8(bit_cast<uint32_t>(x));
#else
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
Expand All @@ -185,13 +264,15 @@ struct PassThroughPack8

__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if 0
vector_type<bhalf_t, 8> result;

result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);

y = result.template AsType<bhalf8_t>()[Number<0>{}];
#elif 1
y = pki4_to_bhalf8(bit_cast<uint32_t>(x));
#else
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct ReferenceGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
i4 = (i4 > 7) ? i4 - 16 : i4;
v_a = type_convert<ComputeTypeA>(i4);
}
else
Expand All @@ -92,7 +92,7 @@ struct ReferenceGemm : public device::BaseOperator
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
i4 = (i4 > 7) ? i4 - 16 : i4;
v_b = type_convert<ComputeTypeB>(i4);
}
else
Expand Down
2 changes: 1 addition & 1 deletion script/cmake-ck-dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fi
cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_HIP_FLAGS="-gline-tables-only -save-temps -Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \
Expand Down