Add autotune support for PT2E (#2110)
Add autotune support for PT2E and disable some conv1d-related tests on HPU
---------

Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: Xin He <[email protected]>
yiliu30 and xin3he authored Jan 24, 2025
1 parent d2e49d2 commit a617115
Showing 7 changed files with 138 additions and 34 deletions.
8 changes: 7 additions & 1 deletion neural_compressor/common/base_config.py
@@ -18,6 +18,7 @@

from __future__ import annotations

import copy
import inspect
import json
import os
@@ -539,6 +540,7 @@ def expand(self) -> List[BaseConfig]:
tuning_param_pair = dict(zip(tuning_param_name_lst, params_values))
tmp_params_dict = {**not_tuning_param_pair, **tuning_param_pair}
new_config = self.__class__(**tmp_params_dict)
new_config.local_config = copy.deepcopy(self.local_config)
logger.info(new_config.to_dict())
config_list.append(new_config)
logger.info("Expanded the %s and got %d configs.", self.__class__.name, len(config_list))
@@ -629,9 +631,13 @@ def __eq__(self, other: BaseConfig) -> bool:
"""
if not isinstance(other, type(self)):
return False
return self.params_list == other.params_list and all(

params_equal = self.params_list == other.params_list and all(
getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
)
local_config_equal = self.local_config == other.local_config
global_config_equal = self.global_config == other.global_config
return params_equal and local_config_equal and global_config_equal


class ComposableConfig(BaseConfig):
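For context on the `__eq__` and `expand()` changes above, a small hedged sketch (not part of this commit; `RTNConfig` and its constructor arguments are assumptions used only to exercise the `BaseConfig` behavior):

```python
# Hedged sketch: configs that differ only in per-op (local) settings no longer compare
# equal, and expanded candidates now carry the local settings along.
from neural_compressor.torch.quantization import RTNConfig  # any BaseConfig subclass works

plain = RTNConfig(bits=[4, 8])  # a list value marks `bits` as a tunable parameter
with_override = RTNConfig(bits=[4, 8]).set_local("lm_head", RTNConfig(dtype="fp32"))

# Previously __eq__ compared only params_list, so these two were considered equal;
# local_config and global_config are now part of the comparison.
print(plain == with_override)  # expected: False

# expand() still yields one candidate per `bits` value, but each expanded config now
# deep-copies local_config instead of dropping the lm_head override.
for candidate in with_override.expand():
    print(candidate.local_config)
```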
4 changes: 3 additions & 1 deletion neural_compressor/torch/algorithms/pt2e_quant/core.py
@@ -18,6 +18,7 @@

from typing import Any

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
@@ -102,4 +103,5 @@ def half_precision_transformation(self, model, config):
"""
half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
hp_rewriter.transformation(model, half_precision_node_set)
hp_rewriter.transformation(model, half_precision_node_set, torch.float16)
hp_rewriter.transformation(model, half_precision_node_set, torch.bfloat16)
66 changes: 46 additions & 20 deletions neural_compressor/torch/algorithms/pt2e_quant/half_precision_rewriter.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Rewrite the FP32 operators to FP16 or BF16 operators."""

from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Tuple
@@ -25,7 +26,7 @@
from torch.fx.subgraph_rewriter import Match
from typing_extensions import TypeAlias

from neural_compressor.common import utils
from neural_compressor.common import logger, utils

# =============================================================================
# Search and replace patterns
@@ -50,25 +51,44 @@ class PatternPair:

# key: torch func
# value: the tuple of args
FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]
FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, List[Tuple[torch.Tensor, ...]]]


# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
# TODO: complete the mapping
# Align with xiq, as it relies on xiq's set_module_xx capability
FN_ARGS_MAPPING: FuncArgsMappingType = {
torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)), # linear w/o bias
torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)), # linear w/ bias
# Note: ORDER matters
torch.nn.functional.linear: [
(torch.randn(0, 0), torch.randn(0, 0)), # linear w/o bias
(torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)), # linear w/ bias
],
torch.nn.functional.conv2d: [
(torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)), # conv2d w/o bias
(torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)), # conv2d w/ bias
],
torch.matmul: [
(torch.randn(0, 0), torch.randn(0, 0)),
(torch.randn(0, 0, 0), torch.randn(0, 0, 0)),
(torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),
],
}
# TODO: complete the mapping
FN_ATEN_OPS_MAPPING = {
torch.nn.functional.linear: torch.ops.aten.linear.default,

# module cls <-> function name
NN_MODULES_TO_NN_FN = {
torch.nn.Linear: torch.nn.functional.linear,
torch.nn.Conv2d: torch.nn.functional.conv2d,
}

# Use the mapping from xiq
FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()

SUPPORTED_OPERATORS = FN_ATEN_OPS_MAPPING.values()


PatternRegistryType: TypeAlias = Dict[TorchFuncType, PatternPair]
HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {torch.float16: {}, torch.bfloat16: {}}
HALF_PRECISION_PATTERN_REGISTRY: Dict[torch.dtype, PatternRegistryType] = {
torch.float16: defaultdict(list),
torch.bfloat16: defaultdict(list),
}

# FP16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.float16]
# BF16_PATTERN_REGISTRY: PatternRegistryType = HALF_PRECISION_PATTERN_REGISTRY[torch.bfloat16]
@@ -98,15 +118,18 @@ def replace_fn_wrapper(fn_args, fn):


def _register_pattern_pair(dtype: torch.dtype) -> None:
for fn, fn_args in FN_ARGS_MAPPING.items():
pattern_pair = pattern_factory(fn, fn_args)
HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
utils.logger.info(
for fn, fn_args_lst in FN_ARGS_MAPPING.items():
for fn_args in fn_args_lst:
logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
pattern_pair = pattern_factory(fn, fn_args)
HALF_PRECISION_PATTERN_REGISTRY[dtype][fn].append(pattern_pair)
utils.logger.debug(
f"Registered {len(HALF_PRECISION_PATTERN_REGISTRY[dtype])} search and replace patterns for {dtype}."
)


_register_pattern_pair(torch.float16)
_register_pattern_pair(torch.bfloat16)


def get_filter_fn(node_list, fn):
@@ -182,9 +205,10 @@ def get_unquantized_node_set(gm: torch.fx.GraphModule):

def transformation(gm: torch.fx.GraphModule, node_candidate_list: List[str], target_dtype: torch.dtype = torch.float16):
"""Convert the nodes in `node_candidate_list` to `target_dtype` if possible."""
for pattern_pair in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
utils.logger.info("Half precision conversion is done:")
for pattern_pair_lst in HALF_PRECISION_PATTERN_REGISTRY[target_dtype].values():
for pattern_pair in pattern_pair_lst:
apply_single_pattern_pair(gm, pattern_pair, node_candidate_list)
utils.logger.info(f"Half precision conversion({target_dtype}) completed.")
if utils.level_name == "DEBUG": # pragma: no cover
gm.print_readable(True)

@@ -201,11 +225,11 @@ def _parse_node_candidate_set_from_user_config(config, gm):
op_name_filters = []
for op_type_name, config in op_type_configs.items(): # pragma: no cover
op_type = getattr(torch.nn, op_type_name)
if config.act_dtype == "fp16": # pragma: no cover
if config.act_dtype in ["fp16", "bf16"]: # pragma: no cover
filter = xpq._get_module_type_filter(op_type)
op_type_filters.append(filter)
for op_name, config in op_name_configs.items():
if config.act_dtype == "fp16": # pragma: no cover
if config.act_dtype in ["fp16", "bf16"]: # pragma: no cover
filter = xpq._get_module_name_filter(op_name)
op_name_filters.append(filter)
node_set_from_user_config = set()
@@ -237,5 +261,7 @@ def get_half_precision_node_set(gm, config):
for node in possible_node_set:
if node.target in SUPPORTED_OPERATORS:
half_precision_node_set.add(node)
utils.logger.info(f"Found {len(half_precision_node_set)} nodes to convert to half precision.")
utils.logger.info(
f"Found {len(half_precision_node_set)} nodes to convert to half precision: {half_precision_node_set}"
)
return half_precision_node_set
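To make the registry change above concrete, a standalone sketch (plain tuples stand in for the real `PatternPair` objects built by `pattern_factory`): each torch function now maps to a list of pattern pairs, one per argument signature (with/without bias, different matmul ranks).

```python
# Hedged, self-contained sketch of the new registry shape; not the library's code.
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

import torch

FN_ARGS_MAPPING: Dict[Callable, List[Tuple[torch.Tensor, ...]]] = {
    torch.nn.functional.linear: [
        (torch.randn(0, 0), torch.randn(0, 0)),                  # linear w/o bias
        (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
    ],
    torch.matmul: [
        (torch.randn(0, 0), torch.randn(0, 0)),        # 2-D matmul
        (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # batched matmul
    ],
}

registry: Dict[torch.dtype, Dict[Callable, list]] = {
    torch.float16: defaultdict(list),
    torch.bfloat16: defaultdict(list),
}

for dtype, per_fn in registry.items():
    for fn, arg_variants in FN_ARGS_MAPPING.items():
        for args in arg_variants:
            # The real code builds a PatternPair (search graph + replacement graph) here;
            # a tuple is used as a stand-in so the sketch stays runnable.
            per_fn[fn].append((fn, len(args), dtype))

# Two pattern variants registered per function and per dtype.
print({fn.__name__: len(pairs) for fn, pairs in registry[torch.float16].items()})
```

The `transformation()` pass above then iterates every pattern pair in the list for the target dtype, which is why it gained the inner loop.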
3 changes: 2 additions & 1 deletion neural_compressor/torch/algorithms/pt2e_quant/utility.py
@@ -26,7 +26,7 @@
from torch.ao.quantization.quantizer import QuantizationSpec
from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer

from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5
from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, logger


def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
@@ -79,6 +79,7 @@ def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=Fals
def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> QuantizationConfig:
NOT_QUANT_DTYPES = ["fp32", "fp16", "bf16"]
if inc_config.act_dtype in NOT_QUANT_DTYPES and inc_config.w_dtype in NOT_QUANT_DTYPES: # pragma: no cover
logger.debug("Got non-quantizable data types, skipping quantization.")
return None
default_quant_config = xiq.get_default_x86_inductor_quantization_config(is_dynamic=is_dynamic)
input_act_quant_spec = create_quant_spec_from_config(
15 changes: 12 additions & 3 deletions neural_compressor/torch/quantization/autotune.py
@@ -54,6 +54,15 @@ def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
return get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)


def _deepcopy_warp(model):
additional_attr_lst = ["_exported", "dynamic_shapes"]
original_attr = {key: getattr(model, key, None) for key in additional_attr_lst}
new_model = deepcopy(model)
for key, value in original_attr.items():
setattr(new_model, key, value)
return new_model


@dump_elapsed_time("Pass auto-tune")
def autotune(
model: torch.nn.Module,
@@ -81,7 +90,7 @@ def autotune(
best_quant_model = None
eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
baseline: float = eval_func_wrapper.evaluate(deepcopy(model))
baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model))
tuning_monitor.set_baseline(baseline)
tuning_logger.tuning_start()
for trial_index, quant_config in enumerate(config_loader, 1):
@@ -90,7 +99,7 @@
logger.info(quant_config.to_dict())
# !!! Make sure to use deepcopy only when inplace is set to `True`.
q_model = quantize(
deepcopy(model),
_deepcopy_warp(model),
quant_config=quant_config,
run_fn=run_fn,
run_args=run_args,
@@ -112,7 +121,7 @@
best_quant_config: BaseConfig = best_trial_record.quant_config
# !!! Make sure to use deepcopy only when inplace is set to `True`.
q_model = quantize(
deepcopy(model),
_deepcopy_warp(model),
quant_config=best_quant_config,
run_fn=run_fn,
run_args=run_args,
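A note on `_deepcopy_warp`: attributes attached to the exported GraphModule after export (the `_exported` flag and `dynamic_shapes`) may not survive a plain `copy.deepcopy`, so the helper records and re-attaches them. A minimal, hedged illustration (the attribute-loss behavior of `GraphModule.__deepcopy__` is the assumption here; a `symbolic_trace`d toy module stands in for the exported model):

```python
import copy

import torch
from torch.fx import symbolic_trace


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = symbolic_trace(Toy())
gm._exported = True  # flag attached dynamically, as the PT2E export path does

naive_copy = copy.deepcopy(gm)
print(getattr(naive_copy, "_exported", None))  # likely None: GraphModule.__deepcopy__ rebuilds from the graph

# What _deepcopy_warp does: remember the whitelisted attributes, deepcopy, re-attach.
additional_attr_lst = ["_exported", "dynamic_shapes"]
original_attr = {key: getattr(gm, key, None) for key in additional_attr_lst}
safe_copy = copy.deepcopy(gm)
for key, value in original_attr.items():
    setattr(safe_copy, key, value)
print(safe_copy._exported)  # True
```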
74 changes: 66 additions & 8 deletions test/3x/torch/quantization/test_pt2e_quant.py
@@ -29,7 +29,6 @@ def _is_ipex_imported():
monkeypatch.setattr("neural_compressor.torch.quantization.algorithm_entry.is_ipex_imported", _is_ipex_imported)
monkeypatch.setattr("neural_compressor.torch.export.pt2e_export.is_ipex_imported", _is_ipex_imported)


class TestPT2EQuantization:
def teardown_class(self):
shutil.rmtree("saved_results", ignore_errors=True)
@@ -53,15 +52,15 @@ def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return bar, example_inputs

@staticmethod
def build_model_include_conv_and_linear():
def build_model_include_conv_and_linear(bias=True):
class Model(torch.nn.Module):
def __init__(self):
def __init__(self, bias=True):
super(Model, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 6, 5)
self.conv1 = torch.nn.Conv2d(3, 6, 5, bias=bias)
self.pool = torch.nn.MaxPool2d(2, 2)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
self.fc2 = torch.nn.Linear(120, 84)
self.conv2 = torch.nn.Conv2d(6, 16, 5, bias=bias)
self.fc1 = torch.nn.Linear(16 * 5 * 5, 120, bias=bias)
self.fc2 = torch.nn.Linear(120, 84, bias=bias)

def forward(self, x):
x = self.conv1(x)
@@ -74,7 +73,7 @@ def forward(self, x):

return x

model = Model()
model = Model(bias)
example_inputs = (torch.randn(1, 3, 32, 32),)
return model, example_inputs

@@ -283,3 +282,62 @@ def test_mixed_fp16_and_int8(self, force_not_import_ipex):
opt_model = torch.compile(converted_model)
out = opt_model(*example_inputs)
assert out is not None

@pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
@pytest.mark.parametrize("half_precision_dtype", ["fp16", "bf16"])
@pytest.mark.parametrize("op_name_or_type", ["conv1", "fc1", torch.nn.Linear, torch.nn.Conv2d])
@pytest.mark.parametrize("bias", [True, False])
def test_auto_tune_mixed_int8_and_16bits(self, half_precision_dtype, op_name_or_type, bias, force_not_import_ipex):
# Just make sure the pattern matches, not the accuracy.
# config1: int8 for all
# config2: half precision for linear/conv
from neural_compressor.torch.quantization.config import INT8StaticQuantConfig
from neural_compressor.torch.quantization.autotune import autotune, TuningConfig

config1 = INT8StaticQuantConfig()
config2 = INT8StaticQuantConfig().set_local(
op_name_or_type, StaticQuantConfig(w_dtype=half_precision_dtype, act_dtype=half_precision_dtype)
)
tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=-0.1)
eval_result = [1, 1, 2]

def fake_eval_fn(model):
res = eval_result.pop(0)
return res

def run_fn(model):
for i in range(2):
model(*example_inputs)

model, example_inputs = self.build_model_include_conv_and_linear(bias)
model = export(model, example_inputs=example_inputs)
qmodel = autotune(
model=model, tune_config=tune_config, eval_fn=fake_eval_fn, run_fn=run_fn, example_inputs=example_inputs
)

# Calculate the expected number of `aten.to` operations based on bias and op_name_or_type
"""
| Bias | op_name | nn.Module |
|-------|---------|-----------|
| True | 4 | 8 |
| False | 3 | 6 |
"""
expected_node_occurrence = {
torch.ops.aten.to.dtype: (3 + int(bias)) * (1 if isinstance(op_name_or_type, str) else 2)
}

expected_node_occurrence = {
torch_test_quant_common.NodeSpec.call_function(k): v for k, v in expected_node_occurrence.items()
}
node_in_graph = self.get_node_in_graph(qmodel)
for node, cnt in expected_node_occurrence.items():
assert (
node_in_graph.get(node, 0) == cnt
), f"Node {node} should occur {cnt} times, but {node_in_graph.get(node, 0)}"
# inference
from torch._inductor import config

config.freezing = True
opt_model = torch.compile(qmodel)
out = opt_model(*example_inputs)
assert out is not None
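For reference, a condensed, hedged sketch of the flow exercised by `test_auto_tune_mixed_int8_and_16bits` above; import paths mirror the test where they are shown, while the `export` import location and the toy model are assumptions:

```python
import torch

from neural_compressor.torch.export import export  # assumed location of the PT2E export helper
from neural_compressor.torch.quantization.autotune import TuningConfig, autotune
from neural_compressor.torch.quantization.config import INT8StaticQuantConfig, StaticQuantConfig

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
example_inputs = (torch.randn(2, 16),)
exported_model = export(model, example_inputs=example_inputs)

# Candidate 1: int8 static quantization everywhere.
# Candidate 2: the same, but the Linear layers fall back to bf16 via a local override.
config1 = INT8StaticQuantConfig()
config2 = INT8StaticQuantConfig().set_local(
    torch.nn.Linear, StaticQuantConfig(w_dtype="bf16", act_dtype="bf16")
)
tune_config = TuningConfig(config_set=[config1, config2], tolerable_loss=0.01)


def eval_fn(q_model):
    # Replace with a real accuracy metric; tuning stops once the loss criterion is met.
    return 1.0


def run_fn(q_model):
    # Calibration pass used by static quantization.
    q_model(*example_inputs)


q_model = autotune(
    model=exported_model,
    tune_config=tune_config,
    eval_fn=eval_fn,
    run_fn=run_fn,
    example_inputs=example_inputs,
)
```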
2 changes: 2 additions & 0 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -309,6 +309,8 @@ def test_rtn_with_quantize_API(self):
), "The results of calling `convert` + `prepare` and calling `quantize` should be equal."

# TODO: (4, True, 32, 0), group_dim=0, format not supported
# TODO [SW-216127]: it's not a high priority, so we can implement it later.
@pytest.mark.skipif(is_hpex_available(), reason="These tests are not supported on HPU for now.")
@pytest.mark.parametrize(
"bits, use_sym, group_size, group_dim",
[
