support power of 2 scaling factors in float8 training
danielvegamyhre committed Feb 5, 2025
1 parent 8afd10e commit f2433b1
Showing 5 changed files with 44 additions and 18 deletions.
16 changes: 9 additions & 7 deletions test/float8/test_base.py
@@ -15,9 +15,9 @@
import torch.nn as nn

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
TORCH_VERSION_AT_LEAST_2_5,
)

if not TORCH_VERSION_AT_LEAST_2_5:
@@ -26,13 +26,13 @@

from torchao.float8.config import (
CastConfig,
e4m3_dtype,
e5m2_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
recipe_name_to_linear_config,
ScalingGranularity,
ScalingType,
e4m3_dtype,
e5m2_dtype,
recipe_name_to_linear_config,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
@@ -48,15 +48,15 @@
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
hp_tensor_and_scale_to_float8,
)
from torchao.float8.float8_utils import (
FP8_TYPES,
compute_error,
config_has_stateful_scaling,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
)
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
@@ -164,7 +164,8 @@ def test_transpose(self):

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
@pytest.mark.parametrize("power_of_2_scale", [True, False])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +174,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
power_of_2_scale=power_of_2_scale,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
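For reference, a minimal standalone sketch of what the new power_of_2_scale parametrization exercises, using only names that appear in this diff (hp_tensor_to_float8_dynamic, e4m3_dtype, LinearMMConfig, ScalingGranularity, and the Float8Tensor._scale attribute); the cast itself should run on CPU:

import torch

from torchao.float8.config import ScalingGranularity, e4m3_dtype
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

# cast axiswise with power-of-2 scale rounding enabled
a = torch.randn(4, 8, 16, dtype=torch.bfloat16)
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    e4m3_dtype,
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
    power_of_2_scale=True,
)

# every rounded scale is an exact power of two, so log2 has no fractional part
log2_scale = torch.log2(a_fp8._scale.float())
assert torch.equal(log2_scale, torch.floor(log2_scale))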
26 changes: 15 additions & 11 deletions test/float8/test_compile.py
@@ -13,9 +13,9 @@
import pytest

from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
is_sm_at_least_89,
is_sm_at_least_90,
TORCH_VERSION_AT_LEAST_2_5,
)

if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
from torchao.float8.config import (
CastConfig,
e4m3_dtype,
Float8LinearConfig,
Float8LinearRecipeName,
ScalingType,
e4m3_dtype,
recipe_name_to_linear_config,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.float8_linear_utils import (
@@ -45,11 +45,7 @@
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
from torchao.float8.float8_utils import config_has_stateful_scaling
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,7 +416,14 @@ def test_sync_amax_func_cuda_graph_success():
torch.float16,
],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
@pytest.mark.parametrize(
"power_of_2_scale",
[
True,
False,
],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool):
scaling_type_weight = ScalingType.DYNAMIC
torch.manual_seed(42)
hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
@@ -456,13 +459,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
power_of_2_scale=power_of_2_scale,
)
torch._dynamo.reset()
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
hp_tensor2,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
power_of_2_scale=power_of_2_scale,
)
assert torch.equal(float8_eager._scale, float8_compile._scale)
assert torch.equal(float8_eager._data, float8_compile._data)
@@ -474,8 +479,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
from torch._inductor import config as inductor_config
from torch._inductor import metrics
from torch._inductor import config as inductor_config, metrics

inductor_config.loop_ordering_after_fusion = True

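A reduced sketch of the compile-parity idea for just the new rounding step (the full test above also compares the quantized data and runs on CUDA); this assumes torch.compile works in the local environment:

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # same rounding applied in hp_tensor_to_float8_dynamic when power_of_2_scale=True
    return torch.exp2(torch.floor(torch.log2(scale)))

scale = torch.rand(16) * 100 + 1e-3
eager = round_scale_down_to_power_of_2(scale)
compiled = torch.compile(round_scale_down_to_power_of_2)(scale)

# eager and compiled paths should produce identical power-of-2 scales
assert torch.equal(eager, compiled)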
9 changes: 9 additions & 0 deletions torchao/float8/config.py
@@ -234,6 +234,13 @@ class Float8LinearConfig:
# tests so that the warning does not spam the CI stdout.
force_recompute_fp8_weight_in_bwd: bool = False

# If this option is enabled, the scaling factor used for float8 quantization
# will be rounded down to the nearest power of 2. This has been shown to help
# reduce quantization error by avoiding rounding errors when multiplying/dividing
# by the scaling factor, and it ensures large values are quantized to the same
# value in the forward and backward passes.
power_of_2_scale: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
@@ -336,6 +343,8 @@ def recipe_name_to_linear_config(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
# enable power of 2 scaling factors by default for row-wise scaling
power_of_2_scale=True,
)

elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
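A short usage sketch of the new config knob; Float8LinearConfig, Float8LinearRecipeName, and recipe_name_to_linear_config are all imported in the tests above, while the ALL_AXISWISE recipe name used below for the row-wise recipe configured here is an assumption:

from torchao.float8.config import (
    Float8LinearConfig,
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# opt in explicitly for the default tensorwise config
config = Float8LinearConfig(power_of_2_scale=True)
assert config.power_of_2_scale

# the row-wise (axiswise) recipe now enables it by default
rowwise_config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert rowwise_config.power_of_2_scale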
6 changes: 6 additions & 0 deletions torchao/float8/float8_linear.py
@@ -96,6 +96,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_input.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_weight.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

# the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_grad_output.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_weight_for_grad_input.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_input_for_grad_weight.scaling_granularity
),
power_of_2_scale=c.power_of_2_scale,
)

grad_weight = torch.mm(
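Because every cast in the forward and backward pass reads c.power_of_2_scale from the linear config, enabling it end to end is a single config flag. A hedged sketch, assuming the public convert_to_float8_training entry point in torchao.float8 and an fp8-capable GPU (SM89+) for actually running the float8 matmuls:

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearConfig

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16)).to(torch.bfloat16)
config = Float8LinearConfig(power_of_2_scale=True)

# swaps nn.Linear modules for Float8Linear; every dynamic cast in forward/backward
# will now round its scaling factor down to a power of 2
model = convert_to_float8_training(model, config=config)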
5 changes: 5 additions & 0 deletions torchao/float8/float8_scaling_utils.py
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
power_of_2_scale: bool = False,
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,7 @@
the 3 fwd/bwd gemms of linear
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
"""
scale = tensor_to_scale(
hp_tensor,
@@ -60,6 +62,9 @@
scaling_granularity,
axiswise_dim,
)
if power_of_2_scale:
# rounds down to the nearest power of 2.
scale = torch.exp2(torch.floor(torch.log2(scale)))
return hp_tensor_and_scale_to_float8(
hp_tensor,
scale,
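The rounding itself is just exp2(floor(log2(scale))), i.e. round down to the nearest power of two; a quick illustration with plain torch:

import torch

scale = torch.tensor([0.75, 1.0, 5.3, 300.0])
rounded = torch.exp2(torch.floor(torch.log2(scale)))

# rounds each value down to the nearest power of two: 0.5, 1.0, 4.0, 256.0
print(rounded)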
