From df45a6b59d5e74743e7f5aeb3e45a4b5f2adfec1 Mon Sep 17 00:00:00 2001 From: Jiming Ruan Date: Fri, 24 Jan 2025 07:29:33 +0000 Subject: [PATCH 1/6] 50ms -> 28ms --- example/ck_tile/10_rmsnorm2d/generate.py | 9 +++-- .../rmsnorm2d_fwd_pipeline_two_pass.hpp | 37 +++++++++++++------ 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index dadb2268b2..b58fba3ffc 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -535,10 +535,11 @@ def get_blobs(self): h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 1024, 8, True, False, True, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0) + ]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index c29a6cb07d..663c513800 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass // compute inv-rms auto inv_rms = tile_elementwise_in( [&](const auto& v_) { - return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + return rsqrtf(v_ / row_size + epsilon); }, square_sum); @@ -136,32 +136,47 @@ struct Rmsnorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); // rmsnorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - auto acc = cast_tile(x); + auto acc = make_static_distributed_tensor(decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE || - kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) { + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); + sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x acc(idx) = type_convert(x_resi(idx)) + acc(idx); }); + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + } + else if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + { + acc = cast_tile(load_tile(y_residual_window)); + move_tile_window(y_residual_window, {0, -Block_N}); } // load gamma (TODO: support no gamma?) const auto gamma = load_tile(gamma_window); // rmsnorm computation - auto rmsn = make_static_distributed_tensor(x.get_tile_distribution()); + auto rmsn = make_static_distributed_tensor(decltype(load_tile(x_window))::get_tile_distribution()); sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -175,9 +190,7 @@ struct Rmsnorm2dFwdPipelineTwoPass static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); Epilogue{}(y_window, rmsn); - - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); + move_tile_window(gamma_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); } From f41a43a18ea54081651a70157ea07ea2e9eda6eb Mon Sep 17 00:00:00 2001 From: Jiming Ruan Date: Mon, 27 Jan 2025 09:07:43 +0000 Subject: [PATCH 2/6] Fix bug in non fuse_add_store cases --- .../rmsnorm2d_fwd_pipeline_two_pass.hpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index 663c513800..371e82b88b 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -153,24 +153,26 @@ struct Rmsnorm2dFwdPipelineTwoPass { auto acc = make_static_distributed_tensor(decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) - { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - - sweep_tile(x_resi, [&](auto idx) { - // compute x = x_resi + x - acc(idx) = type_convert(x_resi(idx)) + acc(idx); - }); - - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - } - else if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) { acc = cast_tile(load_tile(y_residual_window)); move_tile_window(y_residual_window, {0, -Block_N}); } + else + { + acc = cast_tile(load_tile(x_window)); + move_tile_window(x_window, {0, -Block_N}); + + if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD) + { + auto x_resi = load_tile(x_residual_window); + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + }); + move_tile_window(x_residual_window, {0, -Block_N}); + } + } // load gamma (TODO: support no gamma?) const auto gamma = load_tile(gamma_window); From 153240836529a72338343bc1d65dc0b36bdaa08e Mon Sep 17 00:00:00 2001 From: Jiming Ruan Date: Mon, 27 Jan 2025 09:08:19 +0000 Subject: [PATCH 3/6] Fine tuned setting for 2 pass pipeline --- example/ck_tile/10_rmsnorm2d/generate.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index b58fba3ffc..28faa39f90 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -535,11 +535,8 @@ def get_blobs(self): h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 1024, 8, True, False, True, 0, 0), - # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), - # h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), - # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0) - ]} + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 20, 1, 256, 1, True, False, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] From 485e530b60ecbc7668a36748c408a0db272a04f5 Mon Sep 17 00:00:00 2001 From: Jiming Ruan Date: Wed, 5 Feb 2025 06:47:49 +0000 Subject: [PATCH 4/6] adjust workload --- example/ck_tile/10_rmsnorm2d/generate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 28faa39f90..488538ed22 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -536,7 +536,9 @@ def get_blobs(self): h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 20, 1, 256, 1, True, False, True, 0, 0)]} + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 16, 1, 256, 1, True, False, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] From b426b99a23ac0c5aff5eb27e5e7c3e35f9f04aef Mon Sep 17 00:00:00 2001 From: Jiming Ruan Date: Wed, 5 Feb 2025 07:05:13 +0000 Subject: [PATCH 5/6] remove unnecessary change --- .../pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index 371e82b88b..6f02995e00 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -125,7 +125,7 @@ struct Rmsnorm2dFwdPipelineTwoPass // compute inv-rms auto inv_rms = tile_elementwise_in( [&](const auto& v_) { - return rsqrtf(v_ / row_size + epsilon); + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); }, square_sum); @@ -151,7 +151,8 @@ struct Rmsnorm2dFwdPipelineTwoPass // rmsnorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto acc = make_static_distributed_tensor(decltype(load_tile(x_window))::get_tile_distribution()); + auto acc = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD_STORE) { @@ -178,7 +179,8 @@ struct Rmsnorm2dFwdPipelineTwoPass const auto gamma = load_tile(gamma_window); // rmsnorm computation - auto rmsn = make_static_distributed_tensor(decltype(load_tile(x_window))::get_tile_distribution()); + auto rmsn = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -192,7 +194,7 @@ struct Rmsnorm2dFwdPipelineTwoPass static_assert(kFusedQuant == Rmsnorm2dFusedQuantEnum::NO_SWEEP); Epilogue{}(y_window, rmsn); - + move_tile_window(gamma_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N}); } From 28f93c9581f7cbfd0f702a05dbab0a35fb0e2da3 Mon Sep 17 00:00:00 2001 From: Jiming Ruan Date: Thu, 6 Feb 2025 07:01:15 +0000 Subject: [PATCH 6/6] add layernorm --- example/ck_tile/02_layernorm2d/generate.py | 4 +- example/ck_tile/10_rmsnorm2d/generate.py | 2 +- .../layernorm2d_fwd_pipeline_two_pass.hpp | 65 ++++++++++++------- 3 files changed, 45 insertions(+), 26 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 700b007fad..0238a125dc 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -564,9 +564,9 @@ def get_blobs(self, args): h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0, 0), + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 488538ed22..c13f52a3b4 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -538,7 +538,7 @@ def get_blobs(self): 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 16, 1, 256, 1, True, False, True, 0, 0)]} + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index b0b0c194ad..73cdd084c6 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -182,9 +182,16 @@ struct Layernorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - move_tile_window(x_bias_window, {-Block_N}); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) + { + move_tile_window(y_residual_window, {0, -Block_N}); + } + else + { + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(x_residual_window, {0, -Block_N}); + move_tile_window(x_bias_window, {-Block_N}); + } move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window}); move_tile_window(y_window, {0, stride_to_right_most_window}); @@ -192,28 +199,43 @@ struct Layernorm2dFwdPipelineTwoPass // layernorm computation for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - const auto x_bias = load_tile(x_bias_window); - auto acc = cast_tile(x); + auto acc = make_static_distributed_tensor( + decltype(load_tile(x_window))::get_tile_distribution()); - if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) { - sweep_tile(x, [&](auto idx) { - // compute x = bias + x - constexpr auto j_idx = make_tuple(idx[number<1>{}]); - acc(idx) = type_convert(x_bias[j_idx]) + acc(idx); - }); + acc = cast_tile(load_tile(y_residual_window)); + move_tile_window(y_residual_window, {0, -Block_N}); } - - if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || - kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + else { - sweep_tile(x_resi, [&](auto idx) { - // compute x = x_resi + x - acc(idx) = type_convert(x_resi(idx)) + acc(idx); - }); + acc = cast_tile(load_tile(x_window)); + move_tile_window(x_window, {0, -Block_N}); + + if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) + { + const auto x_bias = load_tile(x_bias_window); + move_tile_window(x_bias_window, {-Block_N}); + + sweep_tile(acc, [&](auto idx) { + // compute x = bias + x + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + acc(idx) = type_convert(x_bias[j_idx]) + acc(idx); + }); + } + + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) + { + auto x_resi = load_tile(x_residual_window); + move_tile_window(x_residual_window, {0, -Block_N}); + + sweep_tile(x_resi, [&](auto idx) { + // compute x = x_resi + x + acc(idx) = type_convert(x_resi(idx)) + acc(idx); + }); + } } + // load gamma/beta (TODO: support no gamma/beta?) const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); @@ -235,9 +257,6 @@ struct Layernorm2dFwdPipelineTwoPass static_assert(kFusedQuant != Layernorm2dFusedQuantEnum::DYNAMIC_QUANT); Epilogue{}(y_window, ln); - move_tile_window(x_window, {0, -Block_N}); - move_tile_window(x_residual_window, {0, -Block_N}); - move_tile_window(x_bias_window, {-Block_N}); move_tile_window(gamma_window, {-Block_N}); move_tile_window(beta_window, {-Block_N}); move_tile_window(y_window, {0, -Block_N});