From 9a17900318a8e70d5b6ee0ccd54d7f0ca6385fab Mon Sep 17 00:00:00 2001
From: Nikhil Gupta
Date: Fri, 31 Jan 2025 15:19:13 +0000
Subject: [PATCH] [Fix]: Fall back to KleidiAI channelwise kernel when
 group_size isn't suitable

Description:
1. Some models have certain odd weight shapes that cannot be used with blocked
   quantization. Fall back to channelwise quantization for those shapes.
2. Fix formatting issue in experimental tests.

Signed-off-by: Nikhil Gupta
---
 torchao/experimental/quant_api.py             | 31 ++++++++++++-------
 ...tivation_intx_weight_layout_target_aten.py |  4 +--
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index ea89e98303..65ddef0b89 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -587,8 +587,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 f"granularity must be PerGroup or PerRow, got {granularity}"
             )
 
-        assert weight.shape[-1] % group_size == 0
-
         layout = layout_arg
         scale_dtype = None
         tensor_quantizer = to_affine_quantized_intx
@@ -605,13 +603,8 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
             assert (
                 act_mapping_type == MappingType.ASYMMETRIC
             ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
-            assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
-            layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
-                bit_width=bit_width,
-                group_size=group_size,
-                has_weight_zeros=has_weight_zeros,
-                target="aten" if layout.target == Target.ATEN else "native",
-            )
+            assert not layout.has_params_set(
+            ), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
             if layout.target == Target.ATEN:
                 if (
                     weight_dtype != torch.int4
@@ -628,8 +621,13 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 assert (
                     TORCH_VERSION_AT_LEAST_2_6
                 ), "aten target is requires torch version > 2.6.0"
+                # Fall back to channelwise scheme if group_size is too big
+                if weight.shape[-1] < group_size:
+                    logger.warning(
+                        f"Changing group_size to {weight.shape[-1]}. Weight shape {weight.shape} cannot support group_size {group_size}.")
+                    group_size = weight.shape[-1]
                 if torch.backends.kleidiai.is_available():
-                    if isinstance(granularity, PerGroup):
+                    if weight.shape[-1] != group_size and group_size % 32 == 0:
                         scale_dtype = (
                             torch.bfloat16
                         )  # KleidiAI kernel requires bfloat16 scale_dtype
@@ -637,6 +635,14 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 to_packedlinearint8dynamicactivationintxweight_quantized_intx
             )
 
+            layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
+                bit_width=bit_width,
+                group_size=group_size,
+                has_weight_zeros=has_weight_zeros,
+                target="aten" if layout.target == Target.ATEN else "native",
+            )
+
+        assert weight.shape[-1] % group_size == 0
         quantizer_args = [
             weight,
             weight_mapping_type,
@@ -658,7 +664,7 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
         # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
         # with the kernel and it should not be applied separately
         if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
-            activation_quant_func = lambda x: to_affine_quantized_intx(
+            def activation_quant_func(x): return to_affine_quantized_intx(
                 x,
                 mapping_type=act_mapping_type,
                 block_size=_get_per_token_block_size(x),
@@ -668,7 +674,8 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 scale_dtype=torch.float32,
                 zero_point_dtype=torch.int32,
             )
-            weight = to_linear_activation_quantized(weight, activation_quant_func)
+            weight = to_linear_activation_quantized(
+                weight, activation_quant_func)
         return weight
 
     return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)
diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py
index 2a08d0e548..d590b64d58 100644
--- a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py
+++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py
@@ -6,7 +6,6 @@
 
 import copy
 import unittest
-
 import torch
 
 from torchao.dtypes import PlainLayout
@@ -43,8 +42,7 @@ def test_accuracy(self):
         for has_weight_zeros in [True]:
             for granularity in granularities:
                 print(
-                    f"Testing weight_dtype={weight_dtype}, has_weight_zeros={
-                        has_weight_zeros}, granularity={granularity}"
+                    f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}, granularity={granularity}"
                 )
                 quantized_model = copy.deepcopy(model)
                 quantize_(