From 99e0dc30e0bdc4f8a8cfafef4afac57c9705a468 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 20 Jan 2025 06:18:29 +0000 Subject: [PATCH 1/2] Enable s8 in QuantizedMaxPool2d kernel --- src/ATen/native/quantized/QuantizedMaxPool2d.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/quantized/QuantizedMaxPool2d.cpp b/src/ATen/native/quantized/QuantizedMaxPool2d.cpp index 0d559704d..e1168c332 100644 --- a/src/ATen/native/quantized/QuantizedMaxPool2d.cpp +++ b/src/ATen/native/quantized/QuantizedMaxPool2d.cpp @@ -4,6 +4,7 @@ #include #include #include +#include "c10/core/ScalarType.h" namespace at { namespace native { @@ -32,7 +33,7 @@ class QMaxPool_arr_args final { bool ceil_mode) { // Now we only support Byte, qint is not supported. TORCH_CHECK( - qx.scalar_type() == c10::ScalarType::Byte, + qx.scalar_type() == c10::ScalarType::Byte || qx.scalar_type() == c10::ScalarType::Char, "QuantizedMaxPool2d only supports Byte for xpu now"); return at::native::quantized_max_pool2d_xpu( qx, kernel_size, stride, padding, dilation, ceil_mode); From f369106a7c24cfef02bbb726ba9831322427133f Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 20 Jan 2025 08:51:54 +0000 Subject: [PATCH 2/2] Correct error message --- src/ATen/native/quantized/QuantizedMaxPool2d.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/quantized/QuantizedMaxPool2d.cpp b/src/ATen/native/quantized/QuantizedMaxPool2d.cpp index e1168c332..a86ebcf49 100644 --- a/src/ATen/native/quantized/QuantizedMaxPool2d.cpp +++ b/src/ATen/native/quantized/QuantizedMaxPool2d.cpp @@ -34,7 +34,7 @@ class QMaxPool_arr_args final { // Now we only support Byte, qint is not supported. TORCH_CHECK( qx.scalar_type() == c10::ScalarType::Byte || qx.scalar_type() == c10::ScalarType::Char, - "QuantizedMaxPool2d only supports Byte for xpu now"); + "QuantizedMaxPool2d only supports quantized tensor with Byte/Char type at XPU backend"); return at::native::quantized_max_pool2d_xpu( qx, kernel_size, stride, padding, dilation, ceil_mode); }