From f2433b1ad9522ddf42f635b47666c707e2f9b795 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 10:10:07 -0800 Subject: [PATCH 01/16] support power of 2 scaling factors in float8 training --- test/float8/test_base.py | 16 +++++++++------- test/float8/test_compile.py | 26 +++++++++++++++----------- torchao/float8/config.py | 9 +++++++++ torchao/float8/float8_linear.py | 6 ++++++ torchao/float8/float8_scaling_utils.py | 5 +++++ 5 files changed, 44 insertions(+), 18 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 3e894c02b9..7913cced6a 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -15,9 +15,9 @@ import torch.nn as nn from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, is_sm_at_least_90, + TORCH_VERSION_AT_LEAST_2_5, ) if not TORCH_VERSION_AT_LEAST_2_5: @@ -26,13 +26,13 @@ from torchao.float8.config import ( CastConfig, + e4m3_dtype, + e5m2_dtype, Float8LinearConfig, Float8LinearRecipeName, + recipe_name_to_linear_config, ScalingGranularity, ScalingType, - e4m3_dtype, - e5m2_dtype, - recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -48,15 +48,15 @@ from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, + hp_tensor_and_scale_to_float8, LinearMMConfig, ScaledMMConfig, - hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - FP8_TYPES, compute_error, config_has_stateful_scaling, fp8_tensor_statistics, + FP8_TYPES, tensor_to_scale, ) from torchao.float8.stateful_float8_linear import StatefulFloat8Linear @@ -164,7 +164,8 @@ def test_transpose(self): @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) - def test_axiswise_dynamic_cast(self, shape, axiswise_dim): + @pytest.mark.parametrize("power_of_2_scale", [True, False]) + def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale): a = torch.randn(*shape, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() a_fp8 = hp_tensor_to_float8_dynamic( @@ -173,6 +174,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim): linear_mm_config, scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=axiswise_dim, + power_of_2_scale=power_of_2_scale, ) a_dq = a_fp8.to_original_precision() sqnr = compute_error(a, a_dq) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c42ab8ee77..643e1e9f35 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -13,9 +13,9 @@ import pytest from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, is_sm_at_least_90, + TORCH_VERSION_AT_LEAST_2_5, ) if not TORCH_VERSION_AT_LEAST_2_5: @@ -29,11 +29,11 @@ from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, + e4m3_dtype, Float8LinearConfig, Float8LinearRecipeName, - ScalingType, - e4m3_dtype, recipe_name_to_linear_config, + ScalingType, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -45,11 +45,7 @@ hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - GemmInputRole, - LinearMMConfig, - ScaledMMConfig, -) +from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig from torchao.float8.float8_utils import config_has_stateful_scaling from 
torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -420,7 +416,14 @@ def test_sync_amax_func_cuda_graph_success(): torch.float16, ], ) -def test_dynamic_scale_numeric_parity(dtype: torch.dtype): +@pytest.mark.parametrize( + "power_of_2_scale", + [ + True, + False, + ], +) +def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(42) hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) @@ -456,6 +459,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, + power_of_2_scale=power_of_2_scale, ) torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( @@ -463,6 +467,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, + power_of_2_scale=power_of_2_scale, ) assert torch.equal(float8_eager._scale, float8_compile._scale) assert torch.equal(float8_eager._data, float8_compile._data) @@ -474,8 +479,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): - from torch._inductor import config as inductor_config - from torch._inductor import metrics + from torch._inductor import config as inductor_config, metrics inductor_config.loop_ordering_after_fusion = True diff --git a/torchao/float8/config.py b/torchao/float8/config.py index c7f32cd3fa..1fde6dd4fc 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -234,6 +234,13 @@ class Float8LinearConfig: # tests so that the warning does not spam the CI stdout. force_recompute_fp8_weight_in_bwd: bool = False + # If this option is enabled, the scaling factor used for float8 quantization + # will be rounded down to the nearest power of 2. This has been shown to help + # reduce quantization error by avoiding rounding errors when multiplying/dividing + # by the scaling factor, as well as ensuring large values are quantized to the + # same value in the forward pass as the backward passes. 
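A minimal sketch of the rounding behavior described in the comment above, using the same exp2/floor/log2 formulation this patch adds to hp_tensor_to_float8_dynamic below (example values are illustrative only, not taken from the diff):

# Sketch: round example float32 scales down to the nearest power of 2.
import torch

scales = torch.tensor([0.75, 1.0, 5.5], dtype=torch.float32)
rounded = torch.exp2(torch.floor(torch.log2(scales)))
# rounded == tensor([0.5, 1.0, 4.0]): each entry is the largest power of 2 that
# does not exceed the original scale, so multiplying or dividing by it changes
# only the exponent bits and adds no mantissa rounding error.
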
+ power_of_2_scale: bool = False + def __post_init__(self): # Populate the additional cast overrides, if the user did not specify them # Note: this hacks around the frozen-ness of this dataclass @@ -336,6 +343,8 @@ def recipe_name_to_linear_config( cast_config_input=cc_i, cast_config_weight=cc_w, cast_config_grad_output=cc_go, + # enable power of 2 scaling factors by default for row-wise scaling + power_of_2_scale=True, ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 6b3c0f06df..27abf3f866 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -96,6 +96,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), + power_of_2_scale=c.power_of_2_scale, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -112,6 +113,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), + power_of_2_scale=c.power_of_2_scale, ) # the reshapes are needed in order to make the shapes compatible with @@ -151,6 +153,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), + power_of_2_scale=c.power_of_2_scale, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -181,6 +184,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), + power_of_2_scale=c.power_of_2_scale, ) grad_input = torch.mm( @@ -216,6 +220,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), + power_of_2_scale=c.power_of_2_scale, ) if tensor_already_casted_to_fp8(input_hp_reshaped): @@ -233,6 +238,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), + power_of_2_scale=c.power_of_2_scale, ) grad_weight = torch.mm( diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 0c27e4f3fc..2ba148be24 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + power_of_2_scale: bool = False, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -51,6 +52,7 @@ def hp_tensor_to_float8_dynamic( the 3 fwd/bwd gemms of linear scaling_granularity: Defines the scaling granularity axiswise_dim: if axiswise granularity is used, defines the dim to scale across + power_of_2_scale: if true, round scaling factor down to the nearest power of 2. """ scale = tensor_to_scale( hp_tensor, @@ -60,6 +62,9 @@ def hp_tensor_to_float8_dynamic( scaling_granularity, axiswise_dim, ) + if power_of_2_scale: + # rounds down to the nearest power of 2. 
+ scale = torch.exp2(torch.floor(torch.log2(scale))) return hp_tensor_and_scale_to_float8( hp_tensor, scale, From a9fe17ed07243f80ea5c0a21ddb4dd275d803aa5 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 13:30:11 -0800 Subject: [PATCH 02/16] fix linter issues --- test/float8/test_base.py | 12 ++++++------ test/float8/test_compile.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 7913cced6a..f096eee223 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -15,9 +15,9 @@ import torch.nn as nn from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, is_sm_at_least_90, - TORCH_VERSION_AT_LEAST_2_5, ) if not TORCH_VERSION_AT_LEAST_2_5: @@ -26,13 +26,13 @@ from torchao.float8.config import ( CastConfig, - e4m3_dtype, - e5m2_dtype, Float8LinearConfig, Float8LinearRecipeName, - recipe_name_to_linear_config, ScalingGranularity, ScalingType, + e4m3_dtype, + e5m2_dtype, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -48,15 +48,15 @@ from torchao.float8.float8_tensor import ( Float8Tensor, GemmInputRole, - hp_tensor_and_scale_to_float8, LinearMMConfig, ScaledMMConfig, + hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( + FP8_TYPES, compute_error, config_has_stateful_scaling, fp8_tensor_statistics, - FP8_TYPES, tensor_to_scale, ) from torchao.float8.stateful_float8_linear import StatefulFloat8Linear diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 643e1e9f35..ce00b1ae69 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -13,9 +13,9 @@ import pytest from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, is_sm_at_least_90, - TORCH_VERSION_AT_LEAST_2_5, ) if not TORCH_VERSION_AT_LEAST_2_5: @@ -29,11 +29,11 @@ from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, - e4m3_dtype, Float8LinearConfig, Float8LinearRecipeName, - recipe_name_to_linear_config, ScalingType, + e4m3_dtype, + recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -430,6 +430,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + power_of_2_scale=power_of_2_scale, ) linear_mm_config = LinearMMConfig( # output @@ -459,7 +460,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - power_of_2_scale=power_of_2_scale, + power_of_2_scale=float8_config.power_of_2_scale, ) torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( @@ -467,7 +468,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - power_of_2_scale=power_of_2_scale, + power_of_2_scale=float8_config.power_of_2_scale, ) assert torch.equal(float8_eager._scale, float8_compile._scale) assert torch.equal(float8_eager._data, float8_compile._data) @@ -479,7 +480,8 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool ) @pytest.mark.parametrize("dtype", 
[torch.bfloat16, torch.float16, torch.float32]) def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): - from torch._inductor import config as inductor_config, metrics + from torch._inductor import config as inductor_config + from torch._inductor import metrics inductor_config.loop_ordering_after_fusion = True From 896bd8f73b455f8622277b6645609ed25c4f287c Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 13:58:56 -0800 Subject: [PATCH 03/16] power of 2 scale in amax_to_scale --- torchao/float8/float8_scaling_utils.py | 4 +--- torchao/float8/float8_utils.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 2ba148be24..0dc5f16edb 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -61,10 +61,8 @@ def hp_tensor_to_float8_dynamic( device_mesh, scaling_granularity, axiswise_dim, + power_of_2_scale, ) - if power_of_2_scale: - # rounds down to the nearest power of 2. - scale = torch.exp2(torch.floor(torch.log2(scale))) return hp_tensor_and_scale_to_float8( hp_tensor, scale, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 6a93a612fa..a410751d09 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -10,11 +10,7 @@ import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import ( - Float8LinearConfig, - ScalingGranularity, - ScalingType, -) +from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -33,11 +29,14 @@ @torch.no_grad() -def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): +def amax_to_scale( + amax: torch.Tensor, float8_dtype: torch.dtype, power_of_2_scale: bool = False +): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. float8_dtype: The float8 dtype. + power_of_2_scale: if true, round scaling factor down to the nearest power of 2. """ # torch.compile and eager show different numerics for 1.0 / float32, # upcast to float64 to ensure same numeric between compile and eager @@ -46,7 +45,9 @@ def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") - + if power_of_2_scale: + # rounds down to the nearest power of 2. 
+ res = torch.exp2(torch.floor(torch.log2(res))) return res.to(torch.float32) @@ -125,6 +126,7 @@ def tensor_to_scale( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + power_of_2_scale: bool = False, ) -> torch.Tensor: amax = tensor_to_amax( x, @@ -133,7 +135,7 @@ def tensor_to_scale( scaling_granularity, axiswise_dim, ) - return amax_to_scale(amax, float8_dtype) + return amax_to_scale(amax, float8_dtype, power_of_2_scale=power_of_2_scale) def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): From c70ad6091e85d91bfca3d9f9a85ee9a1ee940edf Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 14:01:42 -0800 Subject: [PATCH 04/16] add docstring to tensor_to_scale --- torchao/float8/float8_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index a410751d09..8661b89b07 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -120,7 +120,7 @@ def tensor_to_amax( @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, + hp_tensor: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False, device_mesh=None, @@ -128,8 +128,19 @@ def tensor_to_scale( axiswise_dim: Optional[int] = None, power_of_2_scale: bool = False, ) -> torch.Tensor: + """ + Compute scaling factor for the given high precision tensor. + + Args: + hp_tensor: high precision tensor + float8_dtype: the float8 dtype to use + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across + power_of_2_scale: if true, round scaling factor down to the nearest power of 2. 
+ """ amax = tensor_to_amax( - x, + hp_tensor, reduce_amax, device_mesh, scaling_granularity, From 34cc033fc094c118fe6173379aa835c05da3cd63 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 14:23:36 -0800 Subject: [PATCH 05/16] rename round_scales_to_power_of_2 --- test/float8/test_base.py | 8 +++++--- test/float8/test_compile.py | 12 +++++++----- torchao/float8/config.py | 4 ++-- torchao/float8/float8_linear.py | 12 ++++++------ torchao/float8/float8_scaling_utils.py | 6 +++--- torchao/float8/float8_utils.py | 16 ++++++++++------ 6 files changed, 33 insertions(+), 25 deletions(-) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index f096eee223..b537c7ab9f 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -164,8 +164,10 @@ def test_transpose(self): @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) - @pytest.mark.parametrize("power_of_2_scale", [True, False]) - def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale): + @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) + def test_axiswise_dynamic_cast( + self, shape, axiswise_dim, round_scales_to_power_of_2 + ): a = torch.randn(*shape, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() a_fp8 = hp_tensor_to_float8_dynamic( @@ -174,7 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale): linear_mm_config, scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=axiswise_dim, - power_of_2_scale=power_of_2_scale, + round_scales_to_power_of_2=round_scales_to_power_of_2, ) a_dq = a_fp8.to_original_precision() sqnr = compute_error(a, a_dq) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index ce00b1ae69..d9c71f7395 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -417,20 +417,22 @@ def test_sync_amax_func_cuda_graph_success(): ], ) @pytest.mark.parametrize( - "power_of_2_scale", + "round_scales_to_power_of_2", [ True, False, ], ) -def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool): +def test_dynamic_scale_numeric_parity( + dtype: torch.dtype, round_scales_to_power_of_2: bool +): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(42) hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), - power_of_2_scale=power_of_2_scale, + round_scales_to_power_of_2=round_scales_to_power_of_2, ) linear_mm_config = LinearMMConfig( # output @@ -460,7 +462,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - power_of_2_scale=float8_config.power_of_2_scale, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( @@ -468,7 +470,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, - power_of_2_scale=float8_config.power_of_2_scale, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) assert torch.equal(float8_eager._scale, float8_compile._scale) assert torch.equal(float8_eager._data, float8_compile._data) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 1fde6dd4fc..2a5a96cc46 100644 
--- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -239,7 +239,7 @@ class Float8LinearConfig: # reduce quantization error by avoiding rounding errors when multiplying/dividing # by the scaling factor, as well as ensuring large values are quantized to the # same value in the forward pass as the backward passes. - power_of_2_scale: bool = False + round_scales_to_power_of_2: bool = False def __post_init__(self): # Populate the additional cast overrides, if the user did not specify them @@ -344,7 +344,7 @@ def recipe_name_to_linear_config( cast_config_weight=cc_w, cast_config_grad_output=cc_go, # enable power of 2 scaling factors by default for row-wise scaling - power_of_2_scale=True, + round_scales_to_power_of_2=True, ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 27abf3f866..0bc2690bc5 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -96,7 +96,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), - power_of_2_scale=c.power_of_2_scale, + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -113,7 +113,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), - power_of_2_scale=c.power_of_2_scale, + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) # the reshapes are needed in order to make the shapes compatible with @@ -153,7 +153,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), - power_of_2_scale=c.power_of_2_scale, + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -184,7 +184,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), - power_of_2_scale=c.power_of_2_scale, + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) grad_input = torch.mm( @@ -220,7 +220,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), - power_of_2_scale=c.power_of_2_scale, + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(input_hp_reshaped): @@ -238,7 +238,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), - power_of_2_scale=c.power_of_2_scale, + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) grad_weight = torch.mm( diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 0dc5f16edb..a8ad4c1920 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -36,7 +36,7 @@ def hp_tensor_to_float8_dynamic( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, - power_of_2_scale: bool = False, + round_scales_to_power_of_2: bool = False, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -52,7 +52,7 @@ def hp_tensor_to_float8_dynamic( the 3 fwd/bwd gemms of linear scaling_granularity: Defines the scaling granularity axiswise_dim: if axiswise granularity is used, defines the dim to scale across - power_of_2_scale: if true, round scaling factor down to the nearest power of 2. 
+ round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. """ scale = tensor_to_scale( hp_tensor, @@ -61,7 +61,7 @@ def hp_tensor_to_float8_dynamic( device_mesh, scaling_granularity, axiswise_dim, - power_of_2_scale, + round_scales_to_power_of_2, ) return hp_tensor_and_scale_to_float8( hp_tensor, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 8661b89b07..f08240a586 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -30,13 +30,15 @@ @torch.no_grad() def amax_to_scale( - amax: torch.Tensor, float8_dtype: torch.dtype, power_of_2_scale: bool = False + amax: torch.Tensor, + float8_dtype: torch.dtype, + round_scales_to_power_of_2: bool = False, ): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. float8_dtype: The float8 dtype. - power_of_2_scale: if true, round scaling factor down to the nearest power of 2. + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. """ # torch.compile and eager show different numerics for 1.0 / float32, # upcast to float64 to ensure same numeric between compile and eager @@ -45,7 +47,7 @@ def amax_to_scale( res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") - if power_of_2_scale: + if round_scales_to_power_of_2: # rounds down to the nearest power of 2. res = torch.exp2(torch.floor(torch.log2(res))) return res.to(torch.float32) @@ -126,7 +128,7 @@ def tensor_to_scale( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, - power_of_2_scale: bool = False, + round_scales_to_power_of_2: bool = False, ) -> torch.Tensor: """ Compute scaling factor for the given high precision tensor. @@ -137,7 +139,7 @@ def tensor_to_scale( reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks scaling_granularity: Defines the scaling granularity axiswise_dim: if axiswise granularity is used, defines the dim to scale across - power_of_2_scale: if true, round scaling factor down to the nearest power of 2. + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. 
""" amax = tensor_to_amax( hp_tensor, @@ -146,7 +148,9 @@ def tensor_to_scale( scaling_granularity, axiswise_dim, ) - return amax_to_scale(amax, float8_dtype, power_of_2_scale=power_of_2_scale) + return amax_to_scale( + amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2 + ) def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): From ab93e1842506ac49151a289f24515190acab9a42 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 5 Feb 2025 16:15:34 -0800 Subject: [PATCH 06/16] use bitshifting for power of 2 rounding --- torchao/float8/float8_scaling_utils.py | 1 + torchao/float8/float8_utils.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index a8ad4c1920..b96c7a9b58 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -27,6 +27,7 @@ ) +# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly def hp_tensor_to_float8_dynamic( hp_tensor: torch.Tensor, float8_dtype: torch.dtype, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index f08240a586..bdb08bbb01 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -45,12 +45,15 @@ def amax_to_scale( amax = amax.to(torch.float64) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) + res = res.to(torch.float32) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") if round_scales_to_power_of_2: - # rounds down to the nearest power of 2. - res = torch.exp2(torch.floor(torch.log2(res))) - return res.to(torch.float32) + # rounds down to the nearest power of 2 + res = res.view(torch.int32) + res = (res >> 23) << 23 + res = res.view(torch.float32) + return res @torch.no_grad() From 56132a3818499cb276ef10d8c6df821fa6e42c67 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 6 Feb 2025 08:26:49 -0800 Subject: [PATCH 07/16] add tests for round to power of 2 --- test/float8/test_utils.py | 35 ++++++++++++++++++++++++++++++++++ torchao/float8/float8_utils.py | 14 ++++++++++---- 2 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 test/float8/test_utils.py diff --git a/test/float8/test_utils.py b/test/float8/test_utils.py new file mode 100644 index 0000000000..34b07d502e --- /dev/null +++ b/test/float8/test_utils.py @@ -0,0 +1,35 @@ +import pytest +import torch + +from torchao.float8.float8_utils import _round_down_to_power_of_2 + + +@pytest.mark.parametrize( + "input_shape", + [ + (1,), + (2, 3), + (8, 2048, 4, 1024), + ], +) +@pytest.mark.parametrize( + "multiplier", + [ + 1.0, + 2.5, + 10.0, + ], +) +def test_round_down_to_power_of_2(input_shape: tuple[int], multiplier: int): + input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier + expected_output = torch.exp2(torch.floor(torch.log2(input_tensor))) + result = _round_down_to_power_of_2(input_tensor) + assert torch.equal( + result, expected_output + ), f"expected {expected_output}, but got {result}" + + +def test_non_float32_input(): + non_float32_tensor = torch.tensor([3.0], dtype=torch.float64) + with pytest.raises(AssertionError, match="input must be float32 tensor"): + _round_down_to_power_of_2(non_float32_tensor) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index bdb08bbb01..a7002516b8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -49,10 +49,7 @@ def amax_to_scale( else: 
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") if round_scales_to_power_of_2: - # rounds down to the nearest power of 2 - res = res.view(torch.int32) - res = (res >> 23) << 23 - res = res.view(torch.float32) + res = _round_down_to_power_of_2(res) return res @@ -286,3 +283,12 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC ) + + +def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.float32, "input must be float32 tensor" + # rounds down to the nearest power of 2 + x = x.view(torch.int32) + x = (x >> 23) << 23 + x = x.view(torch.float32) + return x From 4169927367681010dd72e61ec569ecefd97c5492 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 6 Feb 2025 12:03:47 -0800 Subject: [PATCH 08/16] add unit tests for rounding scale down to nearest power of 2 --- test/float8/test_utils.py | 77 ++++++++++++++++++++++++---------- torchao/float8/float8_utils.py | 29 +++++++++---- 2 files changed, 76 insertions(+), 30 deletions(-) diff --git a/test/float8/test_utils.py b/test/float8/test_utils.py index 34b07d502e..f907a3112a 100644 --- a/test/float8/test_utils.py +++ b/test/float8/test_utils.py @@ -1,35 +1,68 @@ import pytest import torch -from torchao.float8.float8_utils import _round_down_to_power_of_2 +from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 +# source for notable single-precision cases: +# https://en.wikipedia.org/wiki/Single-precision_floating-point_format +# +# TODO(danielvegamyhre): +# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23). +# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value. +# 2. add case for "nan" +# need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf. +# 3. adjust cases for subnormal values so we aren't clamping the expected results +# into the normal range. +# preliminary investigation shows it may not be possible to support all subnormals +# with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x))) +# approach. 
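To make the subnormal caveat in the TODO above concrete, a sketch (not part of the test file) assuming positive float32 scales:

# A subnormal float32 has all-zero exponent bits, so clearing the 23 mantissa
# bits collapses the value to 0.0 instead of rounding it to a power of 2.
import torch

x = torch.tensor([2**-130], dtype=torch.float32)  # subnormal, below 2**-126
bitshifted = ((x.view(torch.int32) >> 23) << 23).view(torch.float32)
reference = torch.exp2(torch.floor(torch.log2(x)))
# bitshifted is tensor([0.]), while the exp2/floor/log2 reference is expected to
# keep a non-zero (subnormal) power of 2, which is the mismatch the TODO describes.
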
@pytest.mark.parametrize( - "input_shape", - [ - (1,), - (2, 3), - (8, 2048, 4, 1024), - ], -) -@pytest.mark.parametrize( - "multiplier", + "input", [ 1.0, - 2.5, - 10.0, + float("inf"), + # smallest positive subnormal number + 2**-126 * 2**-23, + # largest subnormal number + 2**-126 * (1 - 2**-23), + # smallest positive normal number + 2**-126, + # largest number less than one + 1.0 - 2**-24, + # smallest number larger than one + 1.0 + 2**-23, ], ) -def test_round_down_to_power_of_2(input_shape: tuple[int], multiplier: int): - input_tensor = torch.rand(*input_shape, dtype=torch.float32) * multiplier - expected_output = torch.exp2(torch.floor(torch.log2(input_tensor))) - result = _round_down_to_power_of_2(input_tensor) +def test_round_scale_down_to_power_of_2_valid_inputs(input: float): + input_tensor = torch.tensor(input, dtype=torch.float32) + result = _round_scale_down_to_power_of_2(input_tensor) + + # get expected value for comparison + # TODO(danielvegamyhre): support subnormal values + expected_result = torch.exp2(torch.floor(torch.log2(input_tensor))) + smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32) + expected_result = torch.max(expected_result, smallest_normal_fp32_value) + assert torch.equal( - result, expected_output - ), f"expected {expected_output}, but got {result}" + result, expected_result + ), f"input: {input_tensor}, expected {expected_result}, but got {result}" -def test_non_float32_input(): - non_float32_tensor = torch.tensor([3.0], dtype=torch.float64) - with pytest.raises(AssertionError, match="input must be float32 tensor"): - _round_down_to_power_of_2(non_float32_tensor) +@pytest.mark.parametrize( + "invalid_dtype", + [ + torch.bfloat16, + torch.float16, + torch.float64, + torch.int8, + torch.uint8, + torch.int32, + torch.uint32, + torch.int64, + ], +) +def test_non_float32_input(invalid_dtype: torch.dtype): + non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype) + with pytest.raises(AssertionError, match="scale must be float32 tensor"): + _round_scale_down_to_power_of_2(non_float32_tensor) diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index a7002516b8..ea669c08b4 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -49,7 +49,7 @@ def amax_to_scale( else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") if round_scales_to_power_of_2: - res = _round_down_to_power_of_2(res) + res = _round_scale_down_to_power_of_2(res) return res @@ -285,10 +285,23 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: ) -def _round_down_to_power_of_2(x: torch.Tensor) -> torch.Tensor: - assert x.dtype == torch.float32, "input must be float32 tensor" - # rounds down to the nearest power of 2 - x = x.view(torch.int32) - x = (x >> 23) << 23 - x = x.view(torch.float32) - return x +def _round_scale_down_to_power_of_2(x: torch.Tensor): + assert x.dtype == torch.float32, "scale must be float32 tensor" + + # eps = smallest normal fp32 value + # TODO(danielvegamyhre): support subnormal values + eps = 2**-126 + x = torch.clamp( + x, + min=eps, + ) + + # view as int32 to allow bitshifting + x_int = x.view(torch.int32) + + # clear mantissa bits (rightmost 23 bits) + x_int = (x_int >> 23) << 23 + + # return result as fp32 + result = x_int.view(torch.float32) + return result From c43449811ee2ef3a8378afffa749eed2c45703f7 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 6 Feb 2025 12:56:02 -0800 Subject: [PATCH 09/16] rename to test_float8_utils.py --- 
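For reference, the mantissa-clearing rounding implemented in the preceding commits and the exp2/floor/log2 formulation agree for positive normal float32 values; a minimal sketch, not part of any commit in this series:

# For a positive normal float32 x = 2**e * (1 + m / 2**23), zeroing the 23
# mantissa bits leaves 2**e, the largest power of 2 that does not exceed x.
import torch

scales = torch.tensor([1.5, 3.0, 100.0, 2**-126], dtype=torch.float32)
via_bits = ((scales.view(torch.int32) >> 23) << 23).view(torch.float32)
via_log2 = torch.exp2(torch.floor(torch.log2(scales)))
assert torch.equal(via_bits, via_log2)  # both equal [1.0, 2.0, 64.0, 2**-126]
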
test/float8/{test_utils.py => test_float8_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/float8/{test_utils.py => test_float8_utils.py} (100%) diff --git a/test/float8/test_utils.py b/test/float8/test_float8_utils.py similarity index 100% rename from test/float8/test_utils.py rename to test/float8/test_float8_utils.py From 40166e1902a1bcf8b7ea56b1d6fd182f5d849ec8 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 6 Feb 2025 13:43:11 -0800 Subject: [PATCH 10/16] convert to fp32 before rounding scale down to power of 2; update unit tests --- test/float8/test_float8_utils.py | 53 +++++++++++--------------------- torchao/float8/float8_utils.py | 23 ++------------ 2 files changed, 21 insertions(+), 55 deletions(-) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index f907a3112a..184fa21343 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -6,47 +6,30 @@ # source for notable single-precision cases: # https://en.wikipedia.org/wiki/Single-precision_floating-point_format -# -# TODO(danielvegamyhre): -# 1. add case for largest normal fp32 value: 2**127 * (2 - 2**-23). -# need to investigate why exp2(floor(log2(x)))=inf, but bitshift returns real value. -# 2. add case for "nan" -# need to investigate why exp2(floor(log2(nan)))=nan, but bitshift returns inf. -# 3. adjust cases for subnormal values so we aren't clamping the expected results -# into the normal range. -# preliminary investigation shows it may not be possible to support all subnormals -# with bitshifting, so we will need to debug/improve performance of exp2(floor(log2(x))) -# approach. @pytest.mark.parametrize( - "input", + "test_case", [ - 1.0, - float("inf"), - # smallest positive subnormal number - 2**-126 * 2**-23, - # largest subnormal number - 2**-126 * (1 - 2**-23), - # smallest positive normal number - 2**-126, - # largest number less than one - 1.0 - 2**-24, - # smallest number larger than one - 1.0 + 2**-23, + # "test_case_name": [input, expected result] + ("one", [1.0, 1.0]), + ("inf", [float("inf"), float("inf")]), + ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]), + ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), + ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]), + ("smallest positive normal number", [2**-126, 2**-126]), + ("largest number less than one", [1.0 - 2**-24, 0.5]), + ("smallest number larger than one", [1.0 + 2**-23, 1.0]), ], ) -def test_round_scale_down_to_power_of_2_valid_inputs(input: float): - input_tensor = torch.tensor(input, dtype=torch.float32) - result = _round_scale_down_to_power_of_2(input_tensor) - - # get expected value for comparison - # TODO(danielvegamyhre): support subnormal values - expected_result = torch.exp2(torch.floor(torch.log2(input_tensor))) - smallest_normal_fp32_value = torch.tensor(2**-126, dtype=torch.float32) - expected_result = torch.max(expected_result, smallest_normal_fp32_value) +def test_round_scale_down_to_power_of_2_valid_inputs( + test_case: dict, +): + test_case_name, (input, expected_result) = test_case + input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result) + result = _round_scale_down_to_power_of_2(input_tensor) assert torch.equal( - result, expected_result - ), f"input: {input_tensor}, expected {expected_result}, but got {result}" + result, expected_tensor + ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}" 
@pytest.mark.parametrize( diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index ea669c08b4..926b97edb8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -285,23 +285,6 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: ) -def _round_scale_down_to_power_of_2(x: torch.Tensor): - assert x.dtype == torch.float32, "scale must be float32 tensor" - - # eps = smallest normal fp32 value - # TODO(danielvegamyhre): support subnormal values - eps = 2**-126 - x = torch.clamp( - x, - min=eps, - ) - - # view as int32 to allow bitshifting - x_int = x.view(torch.int32) - - # clear mantissa bits (rightmost 23 bits) - x_int = (x_int >> 23) << 23 - - # return result as fp32 - result = x_int.view(torch.float32) - return result +def _round_scale_down_to_power_of_2(scale: torch.Tensor): + assert scale.dtype == torch.float32, "scale must be float32 tensor" + return torch.exp2(torch.floor(torch.log2(scale))) From c6bcac840f176956c9da0688e623fb4e966bb1f2 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 6 Feb 2025 16:54:35 -0800 Subject: [PATCH 11/16] run tests on gpu --- test/float8/test_float8_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 184fa21343..ef0db2658b 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -1,3 +1,5 @@ +import unittest + import pytest import torch @@ -6,6 +8,7 @@ # source for notable single-precision cases: # https://en.wikipedia.org/wiki/Single-precision_floating-point_format +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @pytest.mark.parametrize( "test_case", [ @@ -24,8 +27,10 @@ def test_round_scale_down_to_power_of_2_valid_inputs( test_case: dict, ): test_case_name, (input, expected_result) = test_case - input_tensor, expected_tensor = torch.tensor(input), torch.tensor(expected_result) - + input_tensor, expected_tensor = ( + torch.tensor(input).cuda(), + torch.tensor(expected_result).cuda(), + ) result = _round_scale_down_to_power_of_2(input_tensor) assert torch.equal( result, expected_tensor From 77d004e173ad30d92aa6bc095126875be9f1d817 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 6 Feb 2025 17:03:59 -0800 Subject: [PATCH 12/16] test nan --- test/float8/test_float8_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index ef0db2658b..80b169089e 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -15,6 +15,7 @@ # "test_case_name": [input, expected result] ("one", [1.0, 1.0]), ("inf", [float("inf"), float("inf")]), + ("nan", [float("nan"), float("nan")]), ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]), ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]), @@ -32,8 +33,9 @@ def test_round_scale_down_to_power_of_2_valid_inputs( torch.tensor(expected_result).cuda(), ) result = _round_scale_down_to_power_of_2(input_tensor) - assert torch.equal( - result, expected_tensor + assert ( + torch.equal(result, expected_tensor) + or (result.isnan() and expected_tensor.isnan()) ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}" From 533e027a39a84a5b86f5b30586e3b43605fb9760 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: 
Fri, 7 Feb 2025 09:55:54 -0800 Subject: [PATCH 13/16] skip torch versions < 2.5 --- test/float8/test_float8_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 80b169089e..25c5cdc69c 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -4,6 +4,10 @@ import torch from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) # source for notable single-precision cases: @@ -33,6 +37,11 @@ def test_round_scale_down_to_power_of_2_valid_inputs( torch.tensor(expected_result).cuda(), ) result = _round_scale_down_to_power_of_2(input_tensor) + + print(f"input: {input}") + print(f"input tensor: {input_tensor}") + print(f"result: {result}") + print(f"expected_result: {expected_result}") assert ( torch.equal(result, expected_tensor) or (result.isnan() and expected_tensor.isnan()) From 69dbadb416e0644ec319fd1ce3b1686abf9a2748 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 7 Feb 2025 13:50:58 -0800 Subject: [PATCH 14/16] explicitly use float32 --- test/float8/test_float8_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 25c5cdc69c..10d25dd511 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -33,8 +33,8 @@ def test_round_scale_down_to_power_of_2_valid_inputs( ): test_case_name, (input, expected_result) = test_case input_tensor, expected_tensor = ( - torch.tensor(input).cuda(), - torch.tensor(expected_result).cuda(), + torch.tensor(input, dtype=torch.float32).cuda(), + torch.tensor(expected_result, dtype=torch.float32).cuda(), ) result = _round_scale_down_to_power_of_2(input_tensor) From fa552c6639c38567d7dcd45add69a65e6f6ee7e1 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 8 Feb 2025 18:10:07 -0800 Subject: [PATCH 15/16] add todo for truncation issue --- test/float8/test_float8_utils.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 10d25dd511..ca9f21dde1 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -16,32 +16,30 @@ @pytest.mark.parametrize( "test_case", [ - # "test_case_name": [input, expected result] - ("one", [1.0, 1.0]), - ("inf", [float("inf"), float("inf")]), - ("nan", [float("nan"), float("nan")]), - ("smallest positive subnormal number", [2**-126 * 2**-23, 2**-126 * 2**-23]), - ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), - ("largest normal number", [2**127 * (2 - 2**-23), float("inf")]), - ("smallest positive normal number", [2**-126, 2**-126]), - ("largest number less than one", [1.0 - 2**-24, 0.5]), - ("smallest number larger than one", [1.0 + 2**-23, 1.0]), + # ("test_case_name", input, expected result) + ("one", 1.0, 1.0), + ("inf", float("inf"), float("inf")), + ("nan", float("nan"), float("nan")), + ("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23), + ("largest normal number", 2**127 * (2 - 2**-23), float("inf")), + ("smallest positive normal number", 2**-126, 2**-126), + ("largest number less than one", 1.0 - 2**-24, 0.5), + ("smallest number larger than one", 1.0 + 2**-23, 1.0), + # TODO(danielvegamyhre): 
debug why creating a tensor with largest + # subnormal value in CI env for pytorch 2.5.1 truncates the value to 0. + # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), ], ) def test_round_scale_down_to_power_of_2_valid_inputs( test_case: dict, ): - test_case_name, (input, expected_result) = test_case + test_case_name, input, expected_result = test_case input_tensor, expected_tensor = ( torch.tensor(input, dtype=torch.float32).cuda(), torch.tensor(expected_result, dtype=torch.float32).cuda(), ) result = _round_scale_down_to_power_of_2(input_tensor) - print(f"input: {input}") - print(f"input tensor: {input_tensor}") - print(f"result: {result}") - print(f"expected_result: {expected_result}") assert ( torch.equal(result, expected_tensor) or (result.isnan() and expected_tensor.isnan()) From 21e8061baddaf7fd60831077c3a0e5ac208598c2 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 10 Feb 2025 09:45:58 -0800 Subject: [PATCH 16/16] e4m3 on all casts for fp8 rowwise --- torchao/float8/config.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index 2a5a96cc46..21c3c03283 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -335,9 +335,15 @@ def recipe_name_to_linear_config( elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) return Float8LinearConfig( cast_config_input=cc_i,