diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 3e894c02b9..b537c7ab9f 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -164,7 +164,10 @@ def test_transpose(self): @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) - def test_axiswise_dynamic_cast(self, shape, axiswise_dim): + @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) + def test_axiswise_dynamic_cast( + self, shape, axiswise_dim, round_scales_to_power_of_2 + ): a = torch.randn(*shape, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() a_fp8 = hp_tensor_to_float8_dynamic( @@ -173,6 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim): linear_mm_config, scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=axiswise_dim, + round_scales_to_power_of_2=round_scales_to_power_of_2, ) a_dq = a_fp8.to_original_precision() sqnr = compute_error(a, a_dq) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c42ab8ee77..d9c71f7395 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -45,11 +45,7 @@ hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - GemmInputRole, - LinearMMConfig, - ScaledMMConfig, -) +from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig from torchao.float8.float8_utils import config_has_stateful_scaling from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -420,13 +416,23 @@ def test_sync_amax_func_cuda_graph_success(): torch.float16, ], ) -def test_dynamic_scale_numeric_parity(dtype: torch.dtype): +@pytest.mark.parametrize( + "round_scales_to_power_of_2", + [ + True, + False, + ], +) +def test_dynamic_scale_numeric_parity( + dtype: torch.dtype, round_scales_to_power_of_2: bool +): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(42) hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + round_scales_to_power_of_2=round_scales_to_power_of_2, ) linear_mm_config = LinearMMConfig( # output @@ -456,6 +462,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( @@ -463,6 +470,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) assert torch.equal(float8_eager._scale, float8_compile._scale) assert torch.equal(float8_eager._data, float8_compile._data) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py new file mode 100644 index 0000000000..ca9f21dde1 --- /dev/null +++ b/test/float8/test_float8_utils.py @@ -0,0 +1,65 @@ +import unittest + +import pytest +import torch + +from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +# source for notable single-precision cases: +# 
+# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        # ("test_case_name", input, expected result)
+        ("one", 1.0, 1.0),
+        ("inf", float("inf"), float("inf")),
+        ("nan", float("nan"), float("nan")),
+        ("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23),
+        ("largest normal number", 2**127 * (2 - 2**-23), float("inf")),
+        ("smallest positive normal number", 2**-126, 2**-126),
+        ("largest number less than one", 1.0 - 2**-24, 0.5),
+        ("smallest number larger than one", 1.0 + 2**-23, 1.0),
+        # TODO(danielvegamyhre): debug why creating a tensor with largest
+        # subnormal value in CI env for pytorch 2.5.1 truncates the value to 0.
+        # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
+    ],
+)
+def test_round_scale_down_to_power_of_2_valid_inputs(
+    test_case: tuple,
+):
+    test_case_name, input, expected_result = test_case
+    input_tensor, expected_tensor = (
+        torch.tensor(input, dtype=torch.float32).cuda(),
+        torch.tensor(expected_result, dtype=torch.float32).cuda(),
+    )
+    result = _round_scale_down_to_power_of_2(input_tensor)
+
+    assert (
+        torch.equal(result, expected_tensor)
+        or (result.isnan() and expected_tensor.isnan())
+    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
+
+
+@pytest.mark.parametrize(
+    "invalid_dtype",
+    [
+        torch.bfloat16,
+        torch.float16,
+        torch.float64,
+        torch.int8,
+        torch.uint8,
+        torch.int32,
+        torch.uint32,
+        torch.int64,
+    ],
+)
+def test_non_float32_input(invalid_dtype: torch.dtype):
+    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
+    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        _round_scale_down_to_power_of_2(non_float32_tensor)
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index c7f32cd3fa..21c3c03283 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -234,6 +234,13 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
+    # If this option is enabled, the scaling factor used for float8 quantization
+    # will be rounded down to the nearest power of 2. This has been shown to help
+    # reduce quantization error by avoiding rounding errors when multiplying/dividing
+    # by the scaling factor, as well as ensuring large values are quantized to the
+    # same value in the forward and backward passes.
+ round_scales_to_power_of_2: bool = False + def __post_init__(self): # Populate the additional cast overrides, if the user did not specify them # Note: this hacks around the frozen-ness of this dataclass @@ -328,14 +335,22 @@ def recipe_name_to_linear_config( elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) return Float8LinearConfig( cast_config_input=cc_i, cast_config_weight=cc_w, cast_config_grad_output=cc_go, + # enable power of 2 scaling factors by default for row-wise scaling + round_scales_to_power_of_2=True, ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 6b3c0f06df..0bc2690bc5 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -96,6 +96,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -112,6 +113,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) # the reshapes are needed in order to make the shapes compatible with @@ -151,6 +153,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -181,6 +184,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) grad_input = torch.mm( @@ -216,6 +220,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(input_hp_reshaped): @@ -233,6 +238,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) grad_weight = torch.mm( diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 0c27e4f3fc..b96c7a9b58 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -27,6 +27,7 @@ ) +# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly def hp_tensor_to_float8_dynamic( hp_tensor: torch.Tensor, float8_dtype: torch.dtype, @@ -36,6 +37,7 @@ def hp_tensor_to_float8_dynamic( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + round_scales_to_power_of_2: bool = False, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -51,6 +53,7 @@ def hp_tensor_to_float8_dynamic( 
the 3 fwd/bwd gemms of linear scaling_granularity: Defines the scaling granularity axiswise_dim: if axiswise granularity is used, defines the dim to scale across + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. """ scale = tensor_to_scale( hp_tensor, @@ -59,6 +62,7 @@ def hp_tensor_to_float8_dynamic( device_mesh, scaling_granularity, axiswise_dim, + round_scales_to_power_of_2, ) return hp_tensor_and_scale_to_float8( hp_tensor, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 6a93a612fa..926b97edb8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -10,11 +10,7 @@ import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import ( - Float8LinearConfig, - ScalingGranularity, - ScalingType, -) +from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -33,21 +29,28 @@ @torch.no_grad() -def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): +def amax_to_scale( + amax: torch.Tensor, + float8_dtype: torch.dtype, + round_scales_to_power_of_2: bool = False, +): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. float8_dtype: The float8 dtype. + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. """ # torch.compile and eager show different numerics for 1.0 / float32, # upcast to float64 to ensure same numeric between compile and eager amax = amax.to(torch.float64) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) + res = res.to(torch.float32) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") - - return res.to(torch.float32) + if round_scales_to_power_of_2: + res = _round_scale_down_to_power_of_2(res) + return res @torch.no_grad() @@ -119,21 +122,35 @@ def tensor_to_amax( @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, + hp_tensor: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False, device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + round_scales_to_power_of_2: bool = False, ) -> torch.Tensor: + """ + Compute scaling factor for the given high precision tensor. + + Args: + hp_tensor: high precision tensor + float8_dtype: the float8 dtype to use + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. 
+ """ amax = tensor_to_amax( - x, + hp_tensor, reduce_amax, device_mesh, scaling_granularity, axiswise_dim, ) - return amax_to_scale(amax, float8_dtype) + return amax_to_scale( + amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2 + ) def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): @@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC ) + + +def _round_scale_down_to_power_of_2(scale: torch.Tensor): + assert scale.dtype == torch.float32, "scale must be float32 tensor" + return torch.exp2(torch.floor(torch.log2(scale)))