Skip to content

Commit

Permalink
Fix pk_int4 cast and add pk_int4 dtype in ck tile (#1854)
Browse files Browse the repository at this point in the history
* Fix pk_int4 cast and add pk_int4 dtype in ck tile

* fixes

* Improvements

* fix typo
  • Loading branch information
bartekxk authored Feb 4, 2025
1 parent 9c5b2f3 commit 9ee69dd
Show file tree
Hide file tree
Showing 12 changed files with 406 additions and 73 deletions.
4 changes: 4 additions & 0 deletions include/ck/ck.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0

// shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1

// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ namespace ck {
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q)
// Convert lower part of packed int4 -> int4 to half
__device__ inline half4_t i4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
Expand Down Expand Up @@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
return res.template AsType<half4_t>()[Number<0>{}];
}

__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
Expand Down Expand Up @@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
return res.template AsType<half4_t>()[Number<0>{}];
}

__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);

const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8

int lo = i4s | EX;

return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#else
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
vector_type<half_t, 2> res;
half_t x_h = (x_u8 & 0x0f) - 8;
half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
res.template AsType<half_t>()(Number<0>{}) = x_l;
res.template AsType<half_t>()(Number<1>{}) = x_h;
return res.template AsType<half2_t>()[Number<0>{}];
#endif
}

__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
__device__ inline bhalf4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);

Expand Down Expand Up @@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
return res.template AsType<bhalf4_t>()[Number<0>{}];
}

__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);

float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;

vector_type<bhalf_t, 2> res;

res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);

return res.template AsType<bhalf2_t>()[Number<0>{}];
}

namespace tensor_operation {
namespace element_wise {

Expand All @@ -159,51 +118,51 @@ struct PassThroughPack8

__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<half_t, 8> result;

result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);

y = result.template AsType<half8_t>()[Number<0>{}];
#else
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};

dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
}

__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<bhalf_t, 8> result;

result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);

y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
vector_type<bhalf_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};

dst.template AsType<bhalf2_t>()(Number<0>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<bhalf2_t>()(Number<1>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<bhalf2_t>()(Number<2>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<bhalf2_t>()(Number<3>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
Expand All @@ -219,26 +178,26 @@ struct DequantPack8
__host__ __device__ constexpr void
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<half_t, 8> result;

result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<1>{}) =
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
i4_to_half4_scale(bit_cast<int>(x) >> 8, z);

y = result.template AsType<half8_t>()[Number<0>{}];
#else
vector_type<half_t, 8> dst;
vector_type<pk_i4_t, 4> src{x};

dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);

y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
Expand All @@ -260,7 +219,7 @@ struct PassThroughPack2

__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
Expand Down
70 changes: 65 additions & 5 deletions include/ck/utility/type_convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,35 @@
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"

namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__
#endif

namespace {
namespace details {

[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
half2_t vector_res;

vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;

return vector_res;
}

[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
return amd_assembly_pk_add_f16(x, y);
}
} // namespace details
} // namespace

// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
Expand Down Expand Up @@ -520,13 +542,51 @@ template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;

auto l_f32 = ck::type_convert<float>(x_l);
auto h_f32 = ck::type_convert<float>(x_h);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
float2_t res = {x_h, x_l};
#elif
float2_t res = {x_l, x_h};
#endif
return res;
}

template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif

const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8

int lo = i4s | EX;

return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}

template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);

float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;

#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif

return {l_f32, h_f32};
return res;
}

template <>
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
Expand Down
4 changes: 4 additions & 0 deletions include/ck_tile/core/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif

#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif

// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
Expand Down
Loading

0 comments on commit 9ee69dd

Please sign in to comment.