From 2017d7b1e256f576a3c0f5c222e5ae6f842c42f6 Mon Sep 17 00:00:00 2001
From: Jane Xu
Date: Wed, 29 Jan 2025 11:32:05 -0800
Subject: [PATCH] Clean up code, finish other end of void* boxed kernel

---
 test/test_ops.py                           |  3 +-
 .../tensor_core_tiled_layout.cu            | 79 ++++---------------
 2 files changed, 19 insertions(+), 63 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 26671ddf40..54efefb026 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1,4 +1,5 @@
 import itertools
+import sys
 
 import pytest
 import torch
@@ -614,4 +615,4 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    pytest.main(sys.argv)
diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
index 994374910d..39acdaf4eb 100644
--- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
+++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu
@@ -1,7 +1,7 @@
 #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere
 
-// #include
-// #include
+#include
+#include
 #include
 #include
 #include
@@ -332,25 +332,6 @@ AtenTensorHandle _ATH_dequantize_tensor_core_tiled_layout(
   return out;
 }
 
-// output is [n][k] (int32 dtype)
-// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
-// scales_and_zeros is [numQGroups][n][2]
-// qGroupSize is 32, 64, 128 or 256
-// at::Tensor
-// _dequantize_tensor_core_tiled_layout(const at::Tensor &packed_w,
-//                                      const at::Tensor &scales_and_zeros,
-//                                      int64_t group_size, int64_t innerKTiles) {
-
-//   AtenTensorHandle packed_w_ath =
-//       torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
-//   AtenTensorHandle scales_and_zeros_ath =
-//       torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);
-
-//   AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
-//       packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
-
-//   return *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
-// }
 
 void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
                                                              int64_t num_args,
@@ -360,8 +341,6 @@ void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
   // schema values for now, and run the function and modify the void* stack
   int64_t innerKTiles = reinterpret_cast<int64_t>(stack[3]);
   int64_t group_size = reinterpret_cast<int64_t>(stack[2]);
-  TORCH_WARN(innerKTiles);
-  TORCH_WARN(group_size);
   AtenTensorHandle scales_and_zeros_ath =
       reinterpret_cast<AtenTensorHandle>(stack[1]);
   AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
@@ -386,68 +365,44 @@ void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
   const auto num_arguments = schema.arguments().size();
-  TORCH_CHECK(num_arguments==4);
-  TORCH_CHECK(num_returns==1);
 
   void **ministack =
       (void**)malloc((num_arguments + num_returns) * sizeof(void *));
 
   for (auto idx = 0; idx < num_arguments; idx++) {
-    TORCH_WARN(idx);
     const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
     if (arg.isInt()) {
       ministack[idx] = reinterpret_cast<void *>(arg.toInt());
     } else if (arg.isTensor()) {
-      TORCH_WARN("am tensor!")
       const at::Tensor& tensor = arg.toTensor();
       AtenTensorHandle ath =
           torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
       ministack[idx] = reinterpret_cast<void *>(ath);
     } else {
-      TORCH_CHECK(false, "Other types of IValues not handled!");
TORCH_CHECK(false, "Other types of IValues not yet handled!"); } } - TORCH_WARN("done with forloop no problems!") // second function is going to take a stack of void*, cast them to our // schema values for now, and run the function and modify the void* stack voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments, num_returns); - // now read the output from the end of the stack and wrap that back into - // IValue from void*? - - AtenTensorHandle out_ath = - reinterpret_cast(ministack[num_arguments]); - - free(ministack); - - at::Tensor out = - *torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath); - - // now pop everything. if we pop earlier, Tensors would go out of scope + // now pop all inputs on stack. if we pop earlier, Tensors would go out of scope // before calling the function torch::jit::drop(stack, num_arguments); - torch::jit::push(stack, c10::IValue(out)); - - // so above is our stack of IValues, but we cannot have these IValues because - // they are NOT ABI stable! So we need another version of "boxed" with void*s. - // and that is what is going to happen below - - // what the old function used to be: - // int64_t innerKTiles = torch::jit::pop(stack).toInt(); - // int64_t group_size = torch::jit::pop(stack).toInt(); - // const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor(); - // const at::Tensor &packed_w = torch::jit::pop(stack).toTensor(); - // AtenTensorHandle packed_w_ath = - // torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w); - // AtenTensorHandle scales_and_zeros_ath = - // torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros); - - // AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout( - // packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles); + // read the output from the end of the stack and wrap that back into + // IValue from void*? + for (auto idx = 0; idx < num_returns; idx++) { + const c10::TypePtr& ret_type = schema.returns()[idx].type(); + if (*ret_type == *c10::getTypePtr()) { + AtenTensorHandle ret_ath = reinterpret_cast( ministack[num_arguments + idx]); + at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath); + torch::jit::push(stack, c10::IValue(out)); + } else { + TORCH_CHECK(false, "Only Tensor return types are currently supported!"); + } + } - // at::Tensor out = - // *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res); - // torch::jit::push(stack, c10::IValue(out)); + free(ministack); }