[Fix]: Fall back to KleidiAI channelwise kernel when group_size isn't suitable #1647

Open · wants to merge 1 commit into base: main

31 changes: 19 additions & 12 deletions torchao/experimental/quant_api.py
@@ -587,8 +587,6 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
f"granularity must be PerGroup or PerRow, got {granularity}"
)

assert weight.shape[-1] % group_size == 0

layout = layout_arg
scale_dtype = None
tensor_quantizer = to_affine_quantized_intx
@@ -605,13 +603,8 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
assert (
act_mapping_type == MappingType.ASYMMETRIC
), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
bit_width=bit_width,
group_size=group_size,
has_weight_zeros=has_weight_zeros,
target="aten" if layout.target == Target.ATEN else "native",
)
assert not layout.has_params_set(
), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
if layout.target == Target.ATEN:
if (
weight_dtype != torch.int4
@@ -628,15 +621,28 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
assert (
TORCH_VERSION_AT_LEAST_2_6
), "aten target is requires torch version > 2.6.0"
# Fallback to Channelwise scheme if group_size is too big
if weight.shape[-1] < group_size:
Contributor:
I'd recommend adding a top-level configuration option which clearly tells the user "group size can be changed for certain weight shapes" to handle this case, and throwing an exception if that config setting isn't on.
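
A minimal sketch of that suggestion, assuming a hypothetical top-level flag named allow_group_size_fallback (neither the flag nor the QuantConfig class exists in torchao; names are illustrative only):

from dataclasses import dataclass

@dataclass
class QuantConfig:  # hypothetical top-level config object, not a torchao class
    allow_group_size_fallback: bool = False

def resolve_group_size(k: int, group_size: int, config: QuantConfig) -> int:
    # Only shrink group_size when the user has explicitly opted in;
    # otherwise surface the mismatch as an error.
    if k < group_size:
        if not config.allow_group_size_fallback:
            raise ValueError(
                f"group_size={group_size} does not fit weight dim k={k}; "
                "enable allow_group_size_fallback or use PerRow granularity"
            )
        return k  # channelwise fallback: one group spans the whole row
    return group_size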

Contributor:
I don't understand why we are doing a fallback here. If group_size doesn't divide weight.shape[-1], why not raise exception and let user explicitly move to channelwise?
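
A short sketch of that stricter alternative (an assumption about how it could look, reusing the names from this diff, not the PR's actual code):

if weight.shape[-1] % group_size != 0:
    # Refuse the unsupported group_size instead of silently changing it.
    raise ValueError(
        f"group_size={group_size} does not divide weight.shape[-1]={weight.shape[-1]}; "
        "pass PerRow() granularity explicitly for channelwise quantization"
    )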

logger.warning(f"Changing group_size to {
weight.shape[-1]}. Weight shape {weight.shape} can not support group_size {group_size}.")
group_size = weight.shape[-1]
if torch.backends.kleidiai.is_available():
if isinstance(granularity, PerGroup):
if weight.shape[-1] != group_size and group_size % 32 == 0:
Contributor:
Regarding weight.shape[-1] != group_size: I don't understand this constraint.

IIRC the constraint is k % group_size == 0 and group_size % 32 == 0.
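
A small sketch of the check the reviewer describes, with k standing for weight.shape[-1] (illustrative only, not code from this PR):

def kleidiai_groupwise_ok(k: int, group_size: int) -> bool:
    # Per the reviewer: the reduction dim must be a multiple of group_size,
    # and group_size must be a multiple of 32.
    return k % group_size == 0 and group_size % 32 == 0

assert kleidiai_groupwise_ok(k=256, group_size=32)      # both conditions hold
assert not kleidiai_groupwise_ok(k=256, group_size=48)  # 48 is not a multiple of 32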

scale_dtype = (
torch.bfloat16
) # KleidiAI kernel requires bfloat16 scale_dtype
tensor_quantizer = (
to_packedlinearint8dynamicactivationintxweight_quantized_intx
)

layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
bit_width=bit_width,
group_size=group_size,
has_weight_zeros=has_weight_zeros,
target="aten" if layout.target == Target.ATEN else "native",
)

assert weight.shape[-1] % group_size == 0
quantizer_args = [
weight,
weight_mapping_type,
@@ -658,7 +664,7 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
# Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
# with the kernel and it should not be applied separately
if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
activation_quant_func = lambda x: to_affine_quantized_intx(
def activation_quant_func(x): return to_affine_quantized_intx(
x,
mapping_type=act_mapping_type,
block_size=_get_per_token_block_size(x),
Expand All @@ -668,7 +674,8 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
scale_dtype=torch.float32,
zero_point_dtype=torch.int32,
)
weight = to_linear_activation_quantized(weight, activation_quant_func)
weight = to_linear_activation_quantized(
weight, activation_quant_func)
return weight

return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias)
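
To make the new behavior concrete, a standalone illustration of the decision this diff adds (plain Python sketch, not torchao code; effective_group_size is a made-up helper name):

import torch

def effective_group_size(weight: torch.Tensor, group_size: int) -> int:
    # Mirrors the fallback above: if the reduction dim k is smaller than the
    # requested group_size, clamp group_size to k, i.e. quantize channelwise.
    k = weight.shape[-1]
    if k < group_size:
        return k
    return group_size

weight = torch.randn(32, 64)              # k = 64
print(effective_group_size(weight, 128))  # 64 -> falls back to channelwise
print(effective_group_size(weight, 32))   # 32 -> grouped quantization kept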
Second changed file (unit test):
@@ -6,7 +6,6 @@

import copy
import unittest

import torch

from torchao.dtypes import PlainLayout
@@ -43,8 +42,7 @@ def test_accuracy(self):
for has_weight_zeros in [True]:
for granularity in granularities:
print(
f"Testing weight_dtype={weight_dtype}, has_weight_zeros={
has_weight_zeros}, granularity={granularity}"
f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}, granularity={granularity}"
)
quantized_model = copy.deepcopy(model)
quantize_(