Consolidate ZeroPointDomain.NONE & None zero point domains (#1556)
* Fix ZeroPointDomain.NONE support & make it default for da8w8 weights

* Fix bug & apply review recommendations

* Throw exceptions when None zero_point_domain is used

* Use ZeroPointDomain.NONE for weight in int8_dynamic_activation_int8_weight

* Rebase with the latest main branch

* Fix typo
sanchitintel authored Jan 29, 2025
1 parent abd41e5 commit 7b0d2ce
Showing 10 changed files with 171 additions and 64 deletions.
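
A quick illustration of the consolidated behavior (a minimal sketch mirroring the new test in test/quantization/test_quant_primitives.py below; not part of the diff itself): ZeroPointDomain.NONE now yields a None zero point, while a bare None is rejected with a ValueError.

# Sketch of the consolidated semantics; the call shape mirrors the test below.
import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
)

x = torch.randn(10, 256)
qparam_kwargs = dict(
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int64,
    preserve_zero=True,
)

# With ZeroPointDomain.NONE, qparams are computed without a zero point,
# so the returned zero_point is None.
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    (1, 128),  # block_size
    torch.int8,  # target dtype
    None,  # quant_min
    None,  # quant_max
    1e-6,  # eps
    zero_point_domain=ZeroPointDomain.NONE,
    **qparam_kwargs,
)
assert zero_point is None

# A bare None zero_point_domain now raises instead of being treated as NONE.
try:
    choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        (1, 128),
        torch.int8,
        None,
        None,
        1e-6,
        zero_point_domain=None,
        **qparam_kwargs,
    )
except ValueError:
    pass  # expected after this change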
47 changes: 39 additions & 8 deletions test/integration/test_integration.py
@@ -10,6 +10,7 @@
 import logging
 import os
 import unittest
+from functools import partial

 import torch
 import torch.nn as nn
@@ -48,6 +49,7 @@
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
+    MappingType,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -102,6 +104,8 @@

 COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

+ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
+
 COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()


@@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod):
     quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)


-def _int8da_int8w_api(mod):
+def _int8da_int8w_api(
+    mod,
+    act_mapping_type=MappingType.SYMMETRIC,
+):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(
+                act_mapping_type=act_mapping_type,
+            ),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
@@ -962,25 +975,43 @@ def _test_lin_weight_subclass_api_impl(
         mod[0].weight.tensor_impl.get_plain()

         test = mod(x)
+
         self.assertGreater(
             SQNR(ref_f, test),
             min_sqnr,
-            f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
         )

         mod_qc = torch.compile(mod, mode="max-autotune")
         test_comp = mod_qc(x)
         self.assertGreater(
             SQNR(ref_f, test_comp),
             min_sqnr,
-            f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
+            f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
         )

-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass_api(self, device, dtype):
-        self._test_lin_weight_subclass_api_impl(
-            _int8da_int8w_api, device, 35, test_dtype=dtype
+    @parameterized.expand(
+        list(
+            itertools.product(
+                COMMON_DEVICES,
+                COMMON_DTYPES,
+                ACT_MAPPING_TYPES,
+            )
+        )
+    )
+    def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping):
+        if (
+            not TORCH_VERSION_AT_LEAST_2_5
+            and dtype in (torch.float16, torch.bfloat16)
+            and act_mapping is MappingType.ASYMMETRIC
+            and device == "cpu"
+        ):
+            self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
+        api = partial(
+            _int8da_int8w_api,
+            act_mapping_type=act_mapping,
+        )
+        self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)

     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
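
An aside on the functools.partial pattern introduced above (a toy stand-in, not torchao code): partial pre-binds act_mapping_type so each parameterized case still hands the harness the one-argument callable that _test_lin_weight_subclass_api_impl expects.

from functools import partial

def quant_api(mod, act_mapping_type="SYMMETRIC"):
    # A real API would quantize `mod` in place; this stub just reports.
    return f"{mod} quantized with {act_mapping_type} activations"

# The harness only ever calls api(mod); partial fixes the extra argument.
api = partial(quant_api, act_mapping_type="ASYMMETRIC")
print(api("linear"))  # -> linear quantized with ASYMMETRIC activations
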
17 changes: 9 additions & 8 deletions test/quantization/test_observer.py
@@ -21,6 +21,7 @@
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    ZeroPointDomain,
 )


@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         for example_input in example_inputs:
             obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         example_inputs = [
             torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             eps=torch.finfo(torch.float32).eps,
             scale_dtype=torch.float,
             zero_point_dtype=torch.int,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         if observe_weight:
             weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 eps=torch.finfo(torch.float32).eps,
                 scale_dtype=torch.float,
                 zero_point_dtype=torch.int,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
             )
         else:
             weight_observer = None
@@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
             input_scale.item(),
             max_val / max_fp8,
         )
-        self.assertIsNotNone(input_zero_point)
+        self.assertIsNone(input_zero_point)

         if observe_weight:
             weight_observer = linear.weight.weight_observer
@@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
                 atol=5e-5,
                 rtol=0.0,
             )
-            self.assertIsNotNone(weight_zero_point)
+            self.assertIsNone(weight_zero_point)
         else:
             self.assertIsNone(linear.weight.weight_observer)
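
For context, a condensed sketch of the behavior these assertion changes encode (assumption-laden: the positional granularity argument and the calculate_qparams() call are inferred from the tests above, not verified against the repo): an observer built with ZeroPointDomain.NONE is now expected to report a None zero point.

import torch

from torchao.quantization.granularity import PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    PerTensor(),  # granularity (assumed positional, as asserted in __init__)
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=ZeroPointDomain.NONE,
)
obs(torch.randn(10, 2048))
scale, zero_point = obs.calculate_qparams()  # assumed observer API
assert zero_point is None  # matches the updated assertIsNone checks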

53 changes: 51 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -843,6 +843,55 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)

+    def test_none_zero_point_domain(self):
+        """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
+        input = torch.randn(10, 256)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = None
+        quant_max = None
+        eps = 1e-6
+        scale_dtype = torch.float32
+        zero_point_dtype = torch.int64
+        try:
+            _, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=True,
+                zero_point_domain=None,
+            )
+        except ValueError:
+            # This exception was expected
+            # Now test for ZeroPointDomain.NONE
+            _, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=True,
+                zero_point_domain=ZeroPointDomain.NONE,
+            )
+            self.assertTrue(zero_point is None)
+        else:
+            # An exception should have been thrown for zero_point_domain None
+            self.assertTrue(
+                False,
+                msg="A runtime exception should have been thrown for zero_point_domain None",
+            )
+
     @parameterized.expand(
         [
             (
@@ -890,7 +939,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
             quant_min=torch.finfo(float8_dtype).min,
             quant_max=torch.finfo(float8_dtype).max,
             zero_point=None,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )
         expected_dequantized = dequantize_affine(
             expected_quantized,
@@ -901,7 +950,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
             quant_min=torch.finfo(float8_dtype).min,
             quant_max=torch.finfo(float8_dtype).max,
             zero_point=None,
-            zero_point_domain=None,
+            zero_point_domain=ZeroPointDomain.NONE,
         )

         self.assertTrue(torch.equal(expected_scale, scale))
20 changes: 11 additions & 9 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -81,6 +81,8 @@ def __new__(
         dtype=None,
         strides=None,
     ):
+        if zero_point_domain is None:
+            raise ValueError("please use ZeroPointDomain.NONE instead of None")
         kwargs = {}
         kwargs["device"] = tensor_impl.device
         kwargs["layout"] = (
@@ -199,7 +201,7 @@ def from_hp_to_intx(
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
         use_hqq: bool = False,
     ):
@@ -258,8 +260,7 @@ def from_hp_to_intx(
                 zero_point_domain,
             )
         # choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is NONE
-        # TODO should probably consolidate ZeroPointDomain.NONE and None
-        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+        if zero_point_domain == ZeroPointDomain.NONE:
             zero_point = None
         data = quantize_affine(
             input_float,
@@ -296,14 +297,15 @@ def from_hp_to_intx_static(
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Layout = PlainLayout(),
     ):
         """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters."""
+        if zero_point_domain is None:
+            raise ValueError("please use ZeroPointDomain.NONE instead of None")
+        elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
+            raise ValueError("zero_point should be None when zero_point_domain is NONE")
         if target_dtype not in FP8_TYPES:
-            assert (
-                zero_point_domain is not None
-            ), "zero_point_domain must be specified for non-fp8 types"
             assert (
                 zero_point is not None
             ), "zero_point must be specified for non-fp8 types"
@@ -359,7 +361,7 @@ def from_hp_to_floatx(
                 scale_dtype=scale_dtype,
                 zero_point_dtype=None,
                 preserve_zero=True,
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
                 use_hqq=False,
             )
@@ -387,7 +389,7 @@ def from_hp_to_floatx_static(
                 target_dtype=target_dtype,
                 quant_min=math.ceil(torch.finfo(target_dtype).min),
                 quant_max=math.ceil(torch.finfo(target_dtype).max),
-                zero_point_domain=None,
+                zero_point_domain=ZeroPointDomain.NONE,
                 _layout=_layout,
             )
         else:
4 changes: 3 additions & 1 deletion torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -54,10 +54,12 @@ def from_hp_to_intx(
         block_size: Tuple[int, ...],
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         _layout: Optional[Layout] = None,
     ):
         """Converts a floating point tensor to a Marlin QQQ quantized tensor."""
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         nbits = int(math.log2(quant_max - quant_min + 1))
5 changes: 3 additions & 2 deletions torchao/quantization/observer.py
@@ -104,11 +104,12 @@ def __init__(
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ):
         super().__init__()
         assert granularity is not None, "granularity is None"
-
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         self.mapping_type = mapping_type
         self.target_dtype = target_dtype
         self.granularity = granularity
5 changes: 5 additions & 0 deletions torchao/quantization/qat/affine_fake_quantized_tensor.py
@@ -41,6 +41,9 @@ def forward(
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> "AffineFakeQuantizedTensor":
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
+
         def apply_fake_quant_fn(t: torch.Tensor):
             assert isinstance(t, AffineFakeQuantizedTensor)
             qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
@@ -158,6 +161,8 @@ def from_float(
         preserve_zero: bool = True,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ):
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         return _ToAffineFakeQuantized.apply(
             original_input,
             mapping_type,
2 changes: 2 additions & 0 deletions torchao/quantization/qat/api.py
@@ -96,6 +96,8 @@ def __init__(
         group_size: Optional[int] = None,
         is_symmetric: Optional[bool] = None,
     ):
+        if zero_point_domain is None:
+            raise ValueError("Please use ZeroPointDomain.NONE instead of None")
         self.dtype = dtype
         self.granularity = self._get_granularity(granularity, group_size)
         self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
