From 4ab3e84635b9a0f59a2c86f77289a887ee5d05bf Mon Sep 17 00:00:00 2001
From: Scott Roy <161522778+metascroy@users.noreply.github.com>
Date: Thu, 6 Feb 2025 21:56:40 -0800
Subject: [PATCH] change kleidiai interface

---
 .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h   | 133 ++++++++++++++++--
 .../kernel_selector.h                         |  91 ++++++------
 2 files changed, 161 insertions(+), 63 deletions(-)

diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
index 167ccc47df..92569db4bb 100644
--- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
+++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h
@@ -15,6 +15,13 @@
 
 #include
 #include
+#include
+#include
+
+#ifdef TORCHAO_ENABLE_ARM_I8MM
+#include
+#include
+#endif // TORCHAO_ENABLE_ARM_I8MM
 
 #include
 
@@ -43,14 +50,16 @@ namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
 
 using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel;
 
-size_t activation_data_size(const Ukernel ukernel, int m, int k) {
+size_t activation_data_size(int mr, int kr, int sr, int m, int k) {
   auto lhs_packing = get_lhs_packing();
   return lhs_packing.get_lhs_packed_size(
-      m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr());
+      m, k, mr, kr, sr);
 }
 
 void prepare_activation_data(
-    const Ukernel ukernel,
+    int mr,
+    int kr,
+    int sr,
     void* activation_data,
     int m,
     int k,
@@ -60,29 +69,31 @@ void prepare_activation_data(
   lhs_pack.run_lhs_pack(
       m,
       k,
-      ukernel.get_mr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
+      mr,
+      kr,
+      sr,
       /*m_index_start=*/0,
       activations,
       /*lhs_stride=*/k * sizeof(float),
       activation_data);
 }
 
-size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) {
+size_t weight_data_size(int nr, int kr, int sr, int n, int k, int group_size) {
   auto rhs_pack = get_rhs_packing();
   return rhs_pack.get_rhs_packed_size(
       n,
       k,
-      ukernel.get_nr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
+      nr,
+      kr,
+      sr,
       group_size,
       kai_datatype::kai_dt_bf16);
 }
 
 void prepare_weight_data(
-    const Ukernel ukernel,
+    int nr,
+    int kr,
+    int sr,
     void* weight_data,
     int n,
     int k,
@@ -134,9 +145,9 @@ void prepare_weight_data(
       /*groups=*/1,
       n,
       k,
-      ukernel.get_nr(),
-      ukernel.get_kr(),
-      ukernel.get_sr(),
+      nr,
+      kr,
+      sr,
       group_size,
       /*rhs=*/reinterpret_cast<const uint8_t*>(packed_weight_qvals.data()),
       /*rhs_stride=*/roundup(k, 2) / 2,
@@ -148,5 +159,99 @@ void prepare_weight_data(
       /*qparams=*/&qparams);
 }
 
+
+size_t get_preferred_alignement() {
+  return 16;
+}
+
+
+#define DEFINE_WEIGHT_DATA_FNS(nr, kr, sr) \
+  size_t weight_data_size_nr##nr##_kr##kr##_sr##sr(int n, int k, int group_size) { \
+    return weight_data_size(nr, kr, sr, n, k, group_size); \
+  } \
+  void prepare_weight_data_nr##nr##_kr##kr##_sr##sr( \
+      void* weight_data, \
+      int n, \
+      int k, \
+      int group_size, \
+      const int8_t* weight_qvals, \
+      const float* weight_scales, \
+      const int8_t* weight_zeros, \
+      const float* bias) { \
+    prepare_weight_data(nr, kr, sr, weight_data, n, k, group_size, weight_qvals, weight_scales, weight_zeros, bias); \
+  }
+
+#define DEFINE_ACTIVATION_DATA_FNS(mr, kr, sr) \
+  size_t activation_data_size_mr##mr##_kr##kr##_sr##sr(int m, int k, int group_size) { \
+    (void)group_size; \
+    return activation_data_size(mr, kr, sr, m, k); \
+  } \
+  void prepare_activation_data_mr##mr##_kr##kr##_sr##sr(void* activation_data, int m, int k, int group_size, const float* activations) { \
+    (void)group_size; \
+    prepare_activation_data(mr, kr, sr, activation_data, m, k, activations); \
+  }
+
+// TODO: first and suffix need to be better, e.g., parametrized by mr, nr, etc
+// But I don't quite follow the naming convention for KleidiAI
+#define DEFINE_KERNEL_FNS(first, suffix) \
+  namespace impl_##suffix { \
+  const Ukernel get_ukernel() { \
+    return Ukernel{ \
+        .get_m_step = kai_get_m_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_n_step = kai_get_n_step_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_mr = kai_get_mr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_nr = kai_get_nr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_kr = kai_get_kr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_sr = kai_get_sr_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_lhs_packed_offset = kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_rhs_packed_offset = kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_dst_offset = kai_get_dst_offset_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .get_dst_size = kai_get_dst_size_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix, \
+        .run_matmul = kai_run_matmul_clamp_f32_qai8dxp##first##_qsi4c32p##suffix \
+    }; \
+  } \
+  void kernel( \
+      float32_t* output, \
+      int output_m_stride, \
+      int m, \
+      int n, \
+      int k, \
+      int group_size, \
+      const void* weight_data, \
+      const void* activation_data, \
+      float clamp_min, \
+      float clamp_max) { \
+    get_ukernel().run_matmul( \
+        m, \
+        n, \
+        k, \
+        group_size, \
+        activation_data, \
+        weight_data, \
+        output, \
+        /*dst_stride_row=*/ output_m_stride * sizeof(float), \
+        /*dst_stride_col=*/ sizeof(float), \
+        /*clamp_min=*/std::numeric_limits<float>::lowest(), \
+        /*clamp_max=*/std::numeric_limits<float>::max() \
+    ); \
+  } \
+  }
+
+
+
+DEFINE_WEIGHT_DATA_FNS(/*nr*/8, /*kr*/16, /*sr*/2)
+DEFINE_ACTIVATION_DATA_FNS(/*mr*/1, /*kr*/16, /*sr*/2)
+DEFINE_KERNEL_FNS(1x8, 8x8_1x8x32_neon_dotprod)
+DEFINE_KERNEL_FNS(1x8, 4x8_1x4x32_neon_dotprod)
+
+#ifdef TORCHAO_ENABLE_ARM_I8MM
+DEFINE_KERNEL_FNS(4x8, 4x8_8x4x32_neon_i8mm)
+DEFINE_KERNEL_FNS(4x8, 8x8_4x8x32_neon_i8mm)
+#endif // TORCHAO_ENABLE_ARM_I8MM
+
+#undef DEFINE_WEIGHT_DATA_FNS
+#undef DEFINE_ACTIVATION_DATA_FNS
+#undef DEFINE_KERNEL_FNS
+
 } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
 } // namespace torchao::kernels::cpu::aarch64::kleidi
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
index eeb455bfc4..d380c9e564 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h
@@ -17,12 +17,7 @@
 #include
 
 #if defined(TORCHAO_ENABLE_KLEIDI)
-#include
-#include
-#if defined (TORCHAO_ENABLE_ARM_I8MM)
-#include
-#include
-#endif // TORCHAO_ENABLE_ARM_I8MM
+#include
 #endif // TORCHAO_ENABLE_KLEIDI
 
 namespace torchao::ops::linear_8bit_act_xbit_weight {
@@ -208,44 +203,43 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
         "Kernel expects has_bias=true, but packed_weights have has_bias=" + std::to_string(kleidi_ai_format.has_bias)
     );
   }
+  namespace op = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p;
 
   if (nr == 8 && kr == 16 && sr == 2) {
 #if defined (TORCHAO_ENABLE_ARM_I8MM)
     if (cpuinfo_has_arm_i8mm()) {
-      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;
-      auto uk = kernel::get_ukernel();
-      assert (nr == uk.get_nr());
-      assert (kr == uk.get_kr());
-      assert (sr == uk.get_sr());
-      table.register_ukernel_config(
-        format,
-        uarch,
-        torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/kernel::get_preferred_alignement(),
-          /*weight_packing*/
-          {
-            /*nr*/static_cast<int>(uk.get_n_step()),
-            /*weight_data_size_fn*/&kernel::weight_data_size,
-            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
-          },
-          /*kernels*/
-          {{
+      auto uk = op::impl_8x8_4x8x32_neon_i8mm::get_ukernel();
+      assert (nr == uk.get_nr());
+      assert (kr == uk.get_kr());
+      assert (sr == uk.get_sr());
+      table.register_ukernel_config(
+        format,
+        uarch,
+        torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
+          /*preferred_alignment*/op::get_preferred_alignement(),
+          /*weight_packing*/
+          {
+            /*nr*/static_cast<int>(uk.get_n_step()),
+            /*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
+            /*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
+          },
+          /*kernels*/
+          {{
             {
               /*mr*/static_cast<int>(uk.get_m_step()),
-              /*activation_data_size_fn*/&kernel::activation_data_size,
-              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
-              /*kernel*/&kernel::kernel
+              /*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
+              /*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
+              /*kernel*/&op::impl_8x8_4x8x32_neon_i8mm::kernel
             }
-          }}
-        }
-      );
-      return;
+          }}
+        }
+      );
+      return;
     }
 #endif // TORCHAO_ENABLE_ARM_I8MM
 
     if (cpuinfo_has_arm_neon_dot()) {
-      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32;
-      auto uk = kernel::get_ukernel();
+      auto uk = op::impl_8x8_1x8x32_neon_dotprod::get_ukernel();
       assert (nr == uk.get_nr());
       assert (kr == uk.get_kr());
       assert (sr == uk.get_sr());
@@ -253,20 +247,20 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
         format,
         uarch,
         torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/kernel::get_preferred_alignement(),
+          /*preferred_alignment*/op::get_preferred_alignement(),
           /*weight_packing*/
           {
             /*nr*/static_cast<int>(uk.get_n_step()),
-            /*weight_data_size_fn*/&kernel::weight_data_size,
-            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
+            /*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
+            /*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
           },
           /*kernels*/
           {{
             {
               /*mr*/static_cast<int>(uk.get_m_step()),
-              /*activation_data_size_fn*/&kernel::activation_data_size,
-              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
-              /*kernel*/&kernel::kernel
+              /*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
+              /*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
+              /*kernel*/&op::impl_8x8_1x8x32_neon_dotprod::kernel
             }
           }}
         }
@@ -274,11 +268,10 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
       return;
     }
   }
-  
+
   if (nr == 4 && kr == 16 && sr == 2) {
     if (cpuinfo_has_arm_neon_dot()) {
-      namespace kernel = torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32;
-      auto uk = kernel::get_ukernel();
+      auto uk = op::impl_4x8_1x4x32_neon_dotprod::get_ukernel();
       assert (nr == uk.get_nr());
       assert (kr == uk.get_kr());
       assert (sr == uk.get_sr());
@@ -286,26 +279,26 @@ void register_ukernel_config_kleidi_ai(UKernelConfigRegistrationTable& table, to
         format,
         uarch,
         torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig{
-          /*preferred_alignment*/kernel::get_preferred_alignement(),
+          /*preferred_alignment*/op::get_preferred_alignement(),
           /*weight_packing*/
           {
             /*nr*/static_cast<int>(uk.get_n_step()),
-            /*weight_data_size_fn*/&kernel::weight_data_size,
-            /*prepare_weight_data_fn*/&kernel::prepare_weight_data
+            /*weight_data_size_fn*/&op::weight_data_size_nr8_kr16_sr2,
+            /*prepare_weight_data_fn*/&op::prepare_weight_data_nr8_kr16_sr2
           },
           /*kernels*/
           {{
             {
               /*mr*/static_cast<int>(uk.get_m_step()),
-              /*activation_data_size_fn*/&kernel::activation_data_size,
-              /*prepare_activation_data_fn*/&kernel::prepare_activation_data,
-              /*kernel*/&kernel::kernel
+              /*activation_data_size_fn*/&op::activation_data_size_mr1_kr16_sr2,
+              /*prepare_activation_data_fn*/&op::prepare_activation_data_mr1_kr16_sr2,
+              /*kernel*/&op::impl_4x8_1x4x32_neon_dotprod::kernel
             }
          }}
        }
      );
      return;
-    } 
+    }
  }
 #endif // TORCHAO_ENABLE_KLEIDI
 }
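
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new nr/kr/sr-suffixed packing entry points and one of the macro-generated impl_ namespaces from kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h are expected to fit together. It only calls signatures introduced above; the include path, the problem sizes, the zero-filled quantized data, and the assumption that null weight_zeros/bias mean "none" are placeholders, and it assumes an aarch64 build with the dotprod extension.

#include <cstdint>
#include <limits>
#include <vector>

// Assumption: torchao's repository root is on the include path.
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>

namespace kai_op =
    torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p;

void example_linear_f32_qai8dxp_qsi4c32p() {
  // Placeholder sizes; k must be a multiple of group_size.
  const int m = 4, n = 64, k = 128, group_size = 32;

  // Placeholder 4-bit weight qvals (stored as int8), per-group scales, and activations.
  std::vector<int8_t> weight_qvals(n * k, 0);
  std::vector<float> weight_scales(n * (k / group_size), 1.0f);
  std::vector<float> activations(m * k, 0.0f);

  // Pack weights with the nr=8/kr=16/sr=2 entry points created by
  // DEFINE_WEIGHT_DATA_FNS(8, 16, 2). Assumption: null weight_zeros/bias
  // mean "no zero points / no bias".
  std::vector<char> weight_data(
      kai_op::weight_data_size_nr8_kr16_sr2(n, k, group_size));
  kai_op::prepare_weight_data_nr8_kr16_sr2(
      weight_data.data(), n, k, group_size, weight_qvals.data(),
      weight_scales.data(), /*weight_zeros=*/nullptr, /*bias=*/nullptr);

  // Pack activations with the matching mr=1/kr=16/sr=2 entry points.
  std::vector<char> activation_data(
      kai_op::activation_data_size_mr1_kr16_sr2(m, k, group_size));
  kai_op::prepare_activation_data_mr1_kr16_sr2(
      activation_data.data(), m, k, group_size, activations.data());

  // Run one of the macro-generated dotprod kernels (no i8mm required).
  std::vector<float> output(m * n);
  kai_op::impl_8x8_1x8x32_neon_dotprod::kernel(
      output.data(), /*output_m_stride=*/n, m, n, k, group_size,
      weight_data.data(), activation_data.data(),
      /*clamp_min=*/std::numeric_limits<float>::lowest(),
      /*clamp_max=*/std::numeric_limits<float>::max());
}

The sketch relies on operator new's default allocation alignment (typically 16 bytes) to satisfy get_preferred_alignement(); in the op-level code path the buffers are instead allocated through the UKernelConfig registered in kernel_selector.h.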