[Fix]: Fall back to KleidiAI channelwise kernel when the group size isn't suitable
Description:
1. Some models have weight shapes whose last dimension is smaller than the
   requested group size, so blocked (groupwise) quantization cannot be used.
   Fall back to channelwise quantization for those shapes (a minimal sketch
   of the check follows below).
2. Fix formatting issues in the experimental tests.
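
A minimal sketch of the fallback check added here (not the literal patch): the
helper name resolve_group_size is hypothetical, and the only assumptions are a
weight tensor and a requested group_size.

import logging

import torch

logger = logging.getLogger(__name__)


def resolve_group_size(weight: torch.Tensor, group_size: int) -> int:
    # If the weight's last dimension cannot hold a full group, fall back to
    # channelwise quantization, i.e. one group spanning the whole row.
    if weight.shape[-1] < group_size:
        logger.warning(
            f"Changing group_size to {weight.shape[-1]}. "
            f"Weight shape {weight.shape} cannot support group_size {group_size}."
        )
        return weight.shape[-1]
    return group_size


# Example: 24 input features with a requested group size of 32 falls back to
# channelwise quantization (group_size = 24).
print(resolve_group_size(torch.randn(8, 24), 32))  # -> 24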

Signed-off-by: Nikhil Gupta <[email protected]>
ng-05 committed Jan 31, 2025
1 parent 3eb18e7 commit 9a17900
Showing 2 changed files with 20 additions and 15 deletions.
31 changes: 19 additions & 12 deletions torchao/experimental/quant_api.py
@@ -587,8 +587,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
f"granularity must be PerGroup or PerRow, got {granularity}"
)

assert weight.shape[-1] % group_size == 0

layout = layout_arg
scale_dtype = None
tensor_quantizer = to_affine_quantized_intx
@@ -605,13 +603,8 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
assert (
act_mapping_type == MappingType.ASYMMETRIC
), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
bit_width=bit_width,
group_size=group_size,
has_weight_zeros=has_weight_zeros,
target="aten" if layout.target == Target.ATEN else "native",
)
assert not layout.has_params_set(
), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
if layout.target == Target.ATEN:
if (
weight_dtype != torch.int4
@@ -628,15 +621,28 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
assert (
TORCH_VERSION_AT_LEAST_2_6
), "aten target is requires torch version > 2.6.0"
# Fall back to channelwise scheme if group_size is too big
if weight.shape[-1] < group_size:
logger.warning(f"Changing group_size to {weight.shape[-1]}. Weight shape {weight.shape} cannot support group_size {group_size}.")
group_size = weight.shape[-1]
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
if weight.shape[-1] != group_size and group_size % 32 == 0:
scale_dtype = (
torch.bfloat16
) # KleidiAI kernel requires bfloat16 scale_dtype
tensor_quantizer = (
to_packedlinearint8dynamicactivationintxweight_quantized_intx
)

layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
bit_width=bit_width,
group_size=group_size,
has_weight_zeros=has_weight_zeros,
target="aten" if layout.target == Target.ATEN else "native",
)

assert weight.shape[-1] % group_size == 0
quantizer_args = [
weight,
weight_mapping_type,
@@ -658,7 +664,7 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
# Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
# with the kernel and it should not be applied separately
if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
activation_quant_func = lambda x: to_affine_quantized_intx(
def activation_quant_func(x): return to_affine_quantized_intx(
x,
mapping_type=act_mapping_type,
block_size=_get_per_token_block_size(x),
@@ -668,7 +674,8 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
scale_dtype=torch.float32,
zero_point_dtype=torch.int32,
)
weight = to_linear_activation_quantized(weight, activation_quant_func)
weight = to_linear_activation_quantized(
weight, activation_quant_func)
return weight

return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)
@@ -6,7 +6,6 @@

import copy
import unittest

import torch

from torchao.dtypes import PlainLayout
@@ -43,8 +42,7 @@ def test_accuracy(self):
for has_weight_zeros in [True]:
for granularity in granularities:
print(
f"Testing weight_dtype={weight_dtype}, has_weight_zeros={
has_weight_zeros}, granularity={granularity}"
f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}, granularity={granularity}"
)
quantized_model = copy.deepcopy(model)
quantize_(
