diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index 3e894c02b9..7913cced6a 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -15,9 +15,9 @@
 import torch.nn as nn
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -26,13 +26,13 @@
 
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
+    recipe_name_to_linear_config,
     ScalingGranularity,
     ScalingType,
-    e4m3_dtype,
-    e5m2_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -48,15 +48,15 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    FP8_TYPES,
     compute_error,
     config_has_stateful_scaling,
     fp8_tensor_statistics,
+    FP8_TYPES,
     tensor_to_scale,
 )
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
@@ -164,7 +164,8 @@ def test_transpose(self):
 
     @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
     @pytest.mark.parametrize("axiswise_dim", [0, -1])
-    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+    @pytest.mark.parametrize("power_of_2_scale", [True, False])
+    def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale):
         a = torch.randn(*shape, dtype=torch.bfloat16)
         linear_mm_config = LinearMMConfig()
         a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +174,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=axiswise_dim,
+            power_of_2_scale=power_of_2_scale,
         )
         a_dq = a_fp8.to_original_precision()
         sqnr = compute_error(a, a_dq)
diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index c42ab8ee77..643e1e9f35 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -13,9 +13,9 @@
 import pytest
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    ScalingType,
-    e4m3_dtype,
     recipe_name_to_linear_config,
+    ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -45,11 +45,7 @@
     hp_tensor_to_float8_delayed,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    LinearMMConfig,
-    ScaledMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.float8.float8_utils import config_has_stateful_scaling
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,7 +416,14 @@ def test_sync_amax_func_cuda_graph_success():
         torch.float16,
     ],
 )
-def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+@pytest.mark.parametrize(
+    "power_of_2_scale",
+    [
+        True,
+        False,
+    ],
+)
+def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool):
     scaling_type_weight = ScalingType.DYNAMIC
     torch.manual_seed(42)
     hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
@@ -456,6 +459,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        power_of_2_scale=power_of_2_scale,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
@@ -463,6 +467,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        power_of_2_scale=power_of_2_scale,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)
@@ -474,8 +479,7 @@
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
-    from torch._inductor import config as inductor_config
-    from torch._inductor import metrics
+    from torch._inductor import config as inductor_config, metrics
 
     inductor_config.loop_ordering_after_fusion = True
 
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index c7f32cd3fa..1fde6dd4fc 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -234,6 +234,13 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
+    # If this option is enabled, the scaling factor used for float8 quantization
+    # will be rounded down to the nearest power of 2. This has been shown to help
+    # reduce quantization error by avoiding rounding errors when multiplying/dividing
+    # by the scaling factor, as well as ensuring large values are quantized to the
+    # same value in the forward pass as in the backward pass.
+    power_of_2_scale: bool = False
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
@@ -336,6 +343,8 @@ def recipe_name_to_linear_config(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
+            # enable power of 2 scaling factors by default for row-wise scaling
+            power_of_2_scale=True,
         )
 
     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index 6b3c0f06df..27abf3f866 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -96,6 +96,7 @@ def forward(
                 axiswise_dim=get_maybe_axiswise_dim(
                     -1, c.cast_config_input.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@ def forward(
                 axiswise_dim=get_maybe_axiswise_dim(
                     0, c.cast_config_weight.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         # the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     -1, c.cast_config_grad_output.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     -1, c.cast_config_weight_for_grad_input.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     0, c.cast_config_input_for_grad_weight.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         grad_weight = torch.mm(
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index 0c27e4f3fc..2ba148be24 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    power_of_2_scale: bool = False,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,7 @@ def hp_tensor_to_float8_dynamic(
             the 3 fwd/bwd gemms of linear
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
     """
     scale = tensor_to_scale(
         hp_tensor,
@@ -60,6 +62,9 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity,
         axiswise_dim,
     )
+    if power_of_2_scale:
+        # rounds down to the nearest power of 2.
+        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
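
Note: the snippet below is a standalone sketch, not part of the patch, illustrating what the new `power_of_2_scale` flag does to a dynamically computed scale. It assumes the usual amax-based scale (`float8_max / amax`) that `tensor_to_scale` produces; the `E4M3_MAX` name and the sample tensor are illustrative only.

```python
# Standalone sketch of the power-of-2 scale rounding added in this patch.
# Assumption: `scale` is the usual amax-based dynamic scale (float8_max / amax);
# E4M3_MAX and the sample tensor are illustrative, not taken from the diff.
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

x = torch.randn(16, 16, dtype=torch.bfloat16)
amax = x.abs().max().to(torch.float32)
scale = E4M3_MAX / torch.clamp(amax, min=1e-12)

# the rounding added to hp_tensor_to_float8_dynamic: floor in log2 space,
# i.e. round the scale down to the nearest power of 2
scale_pow2 = torch.exp2(torch.floor(torch.log2(scale)))

# a power-of-2 scale has an all-zero mantissa, so multiplying or dividing by it
# only shifts the exponent and adds no rounding error of its own
# (barring overflow/underflow)
assert scale_pow2 <= scale
assert torch.log2(scale_pow2) == torch.floor(torch.log2(scale))
print(f"original scale: {scale.item():.6f}, power-of-2 scale: {scale_pow2.item():.6f}")
```

Per the `config.py` hunk above, `Float8LinearConfig.power_of_2_scale` defaults to `False`, while the row-wise scaling recipe built in `recipe_name_to_linear_config` enables it by default.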