Clean up code, finish other end of void* boxed kernel
janeyx99 committed Jan 29, 2025
1 parent 9f816b5 commit 2017d7b
Showing 2 changed files with 19 additions and 63 deletions.
3 changes: 2 additions & 1 deletion test/test_ops.py
@@ -1,4 +1,5 @@
import itertools
import sys

import pytest
import torch
@@ -614,4 +615,4 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact


if __name__ == "__main__":
pytest.main([__file__])
pytest.main(sys.argv)
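Passing sys.argv through instead of hard-coding [__file__] lets extra command-line flags reach pytest when the file is run directly, e.g. something like python test/test_ops.py -k test_marlin_qqq -v to filter tests (that invocation is illustrative, not taken from this commit).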
79 changes: 17 additions & 62 deletions (second changed file: the CUDA source for the tensor-core tiled layout dequantize kernel; path not shown in this view)
@@ -1,7 +1,7 @@
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // at least Ampere

// #include <ATen/ATen.h>
// #include <ATen/core/Tensor.h>
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/DeviceGuard.h>
#include <ATen/core/TensorAccessor.h>
#include <ATen/core/ivalue.h>
@@ -332,25 +332,6 @@ AtenTensorHandle _ATH_dequantize_tensor_core_tiled_layout(
return out;
}

// output is [n][k] (int32 dtype)
// input is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2]
// scales_and_zeros is [numQGroups][n][2]
// qGroupSize is 32, 64, 128 or 256
// at::Tensor
// _dequantize_tensor_core_tiled_layout(const at::Tensor &packed_w,
// const at::Tensor &scales_and_zeros,
// int64_t group_size, int64_t innerKTiles) {

// AtenTensorHandle packed_w_ath =
// torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
// AtenTensorHandle scales_and_zeros_ath =
// torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);

// AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
// packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);

// return *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
// }

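For reference, the layout comment removed above implies the following shape relationships. This is an illustrative sketch only: the names are hypothetical, and numQGroups = k / group_size is an assumption consistent with that comment rather than something stated in this diff.

// Sketch: shape arithmetic implied by the removed layout comment (names hypothetical).
#include <cstdint>
inline void tiled_layout_shapes(int64_t packed_dim0, int64_t packed_dim1,
                                int64_t innerKTiles, int64_t group_size,
                                int64_t& n, int64_t& k, int64_t& numQGroups) {
  n = packed_dim0 * 8;                 // packed dim 0 is n / 8
  k = packed_dim1 * innerKTiles * 16;  // packed dim 1 is k / (innerKTiles * 16)
  numQGroups = k / group_size;         // group_size is 32, 64, 128 or 256
  // out is [n][k] (int32); scales_and_zeros is [numQGroups][n][2]
}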
void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
int64_t num_args,
@@ -360,8 +341,6 @@ void voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(void **stack,
// schema values for now, and run the function and modify the void* stack
int64_t innerKTiles = reinterpret_cast<int64_t>(stack[3]);
int64_t group_size = reinterpret_cast<int64_t>(stack[2]);
TORCH_WARN(innerKTiles);
TORCH_WARN(group_size);
AtenTensorHandle scales_and_zeros_ath =
reinterpret_cast<AtenTensorHandle>(stack[1]);
AtenTensorHandle packed_w_ath = reinterpret_cast<AtenTensorHandle>(stack[0]);
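For context, the convention used here is that each argument occupies one pointer-sized slot of the void* stack (integers stored directly in the slot value, tensors as AtenTensorHandle), and the callee writes its returns into the slots just past the arguments. A minimal caller-side sketch under that assumption (the handle and integer variables are hypothetical placeholders):

// Sketch of the type-erased calling convention; not the dispatcher glue shown below.
// Assumes packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles already exist.
void* slots[5];                                            // 4 arguments + 1 return
slots[0] = reinterpret_cast<void*>(packed_w_ath);          // Tensor passed as AtenTensorHandle
slots[1] = reinterpret_cast<void*>(scales_and_zeros_ath);  // Tensor passed as AtenTensorHandle
slots[2] = reinterpret_cast<void*>(group_size);            // int64_t stored directly in the slot
slots[3] = reinterpret_cast<void*>(innerKTiles);           // int64_t stored directly in the slot
voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(slots, /*num_args=*/4, /*num_returns=*/1);
AtenTensorHandle out_ath = reinterpret_cast<AtenTensorHandle>(slots[4]);  // return lands at slots[num_args]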
@@ -386,68 +365,44 @@ void boxed_dequantize_tensor_core_tiled_layout(const c10::OperatorHandle &op,
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
TORCH_CHECK(num_arguments==4);
TORCH_CHECK(num_returns==1);
void **ministack = (void**)malloc((num_arguments + num_returns) * sizeof(void *));

for (auto idx = 0; idx < num_arguments; idx++) {
TORCH_WARN(idx);
const c10::IValue& arg = torch::jit::peek(stack, idx, num_arguments);
if (arg.isInt()) {
ministack[idx] = reinterpret_cast<void *>(arg.toInt());
} else if (arg.isTensor()) {
TORCH_WARN("am tensor!")
const at::Tensor& tensor = arg.toTensor();
AtenTensorHandle ath = torch::aot_inductor::tensor_pointer_to_tensor_handle(&tensor);
ministack[idx] = reinterpret_cast<void *>(ath);
} else {
TORCH_CHECK(false, "Other types of IValues not handled!");
TORCH_CHECK(false, "Other types of IValues not yet handled!");
}
}
TORCH_WARN("done with forloop no problems!")

// second function is going to take a stack of void*, cast them to our
// schema values for now, and run the function and modify the void* stack
voidyvoid_boxed_ATH_dequantize_tensor_core_tiled_layout(ministack, num_arguments,
num_returns);

// now read the output from the end of the stack and wrap that back into
// IValue from void*?

AtenTensorHandle out_ath =
reinterpret_cast<AtenTensorHandle>(ministack[num_arguments]);

free(ministack);

at::Tensor out =
*torch::aot_inductor::tensor_handle_to_tensor_pointer(out_ath);

// now pop everything. if we pop earlier, Tensors would go out of scope
// now pop all inputs on stack. if we pop earlier, Tensors would go out of scope
// before calling the function
torch::jit::drop(stack, num_arguments);
torch::jit::push(stack, c10::IValue(out));

// so above is our stack of IValues, but we cannot have these IValues because
// they are NOT ABI stable! So we need another version of "boxed" with void*s.
// and that is what is going to happen below

// what the old function used to be:
// int64_t innerKTiles = torch::jit::pop(stack).toInt();
// int64_t group_size = torch::jit::pop(stack).toInt();
// const at::Tensor &scales_and_zeros = torch::jit::pop(stack).toTensor();
// const at::Tensor &packed_w = torch::jit::pop(stack).toTensor();

// AtenTensorHandle packed_w_ath =
// torch::aot_inductor::tensor_pointer_to_tensor_handle(&packed_w);
// AtenTensorHandle scales_and_zeros_ath =
// torch::aot_inductor::tensor_pointer_to_tensor_handle(&scales_and_zeros);

// AtenTensorHandle ath_res = _ATH_dequantize_tensor_core_tiled_layout(
// packed_w_ath, scales_and_zeros_ath, group_size, innerKTiles);
// read the output from the end of the stack and wrap that back into
// IValue from void*?
for (auto idx = 0; idx < num_returns; idx++) {
const c10::TypePtr& ret_type = schema.returns()[idx].type();
if (*ret_type == *c10::getTypePtr<at::Tensor>()) {
AtenTensorHandle ret_ath = reinterpret_cast<AtenTensorHandle>( ministack[num_arguments + idx]);
at::Tensor out = *torch::aot_inductor::tensor_handle_to_tensor_pointer(ret_ath);
torch::jit::push(stack, c10::IValue(out));
} else {
TORCH_CHECK(false, "Only Tensor return types are currently supported!");
}
}

// at::Tensor out =
// *torch::aot_inductor::tensor_handle_to_tensor_pointer(ath_res);
// torch::jit::push(stack, c10::IValue(out));
free(ministack);
}
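For completeness, a boxed kernel with this (OperatorHandle, stack) shape is typically wired into the dispatcher via a boxed registration. The sketch below shows one common pattern; the namespace, operator name, and dispatch key are assumptions for illustration, not taken from this diff:

// Sketch only; requires #include <torch/library.h>.
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  // The dispatcher calls the boxed wrapper with the operator handle and an
  // IValue stack; the wrapper above then shuttles everything through void*.
  m.impl("dequantize_tensor_core_tiled_layout",
         torch::CppFunction::makeFromBoxedFunction<
             &boxed_dequantize_tensor_core_tiled_layout>());
}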


