Support power of 2 scaling factors in float8 training and use e4m3 everywhere #1670

Merged (16 commits) on Feb 10, 2025
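As a quick orientation before the diff, here is a minimal sketch of how the new option is meant to be turned on from user code. It assumes the existing torchao.float8 convert_to_float8_training entry point and uses a made-up toy model, so treat it as illustrative rather than as part of this PR:

import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy model used only for illustration; any nn.Module with nn.Linear layers works.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096, bias=False),
    torch.nn.Linear(4096, 1024, bias=False),
).to("cuda", dtype=torch.bfloat16)

# New option added in this PR: round each scaling factor down to the
# nearest power of 2 before casting tensors to float8.
config = Float8LinearConfig(round_scales_to_power_of_2=True)
convert_to_float8_training(model, config=config)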
6 changes: 5 additions & 1 deletion test/float8/test_base.py
@@ -164,7 +164,10 @@ def test_transpose(self):

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
def test_axiswise_dynamic_cast(
self, shape, axiswise_dim, round_scales_to_power_of_2
):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
linear_mm_config,
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
20 changes: 14 additions & 6 deletions test/float8/test_compile.py
@@ -45,11 +45,7 @@
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
GemmInputRole,
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
from torchao.float8.float8_utils import config_has_stateful_scaling
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,13 +416,23 @@ def test_sync_amax_func_cuda_graph_success():
torch.float16,
],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
@pytest.mark.parametrize(
"round_scales_to_power_of_2",
[
True,
False,
],
)
def test_dynamic_scale_numeric_parity(
dtype: torch.dtype, round_scales_to_power_of_2: bool
):
scaling_type_weight = ScalingType.DYNAMIC
torch.manual_seed(42)
hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
hp_tensor2 = hp_tensor1.detach().clone()
float8_config = Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
round_scales_to_power_of_2=round_scales_to_power_of_2,
)
linear_mm_config = LinearMMConfig(
# output
@@ -456,13 +462,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
torch._dynamo.reset()
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
hp_tensor2,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
assert torch.equal(float8_eager._scale, float8_compile._scale)
assert torch.equal(float8_eager._data, float8_compile._data)
65 changes: 65 additions & 0 deletions test/float8/test_float8_utils.py
@@ -0,0 +1,65 @@
import unittest

import pytest
import torch

from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


# source for notable single-precision cases:
# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@pytest.mark.parametrize(
"test_case",
[
# ("test_case_name", input, expected result)
("one", 1.0, 1.0),
("inf", float("inf"), float("inf")),
("nan", float("nan"), float("nan")),
("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23),
("largest normal number", 2**127 * (2 - 2**-23), float("inf")),
("smallest positive normal number", 2**-126, 2**-126),
("largest number less than one", 1.0 - 2**-24, 0.5),
("smallest number larger than one", 1.0 + 2**-23, 1.0),
# TODO(danielvegamyhre): debug why creating a tensor with largest
# subnormal value in CI env for pytorch 2.5.1 truncates the value to 0.
# ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
],
)
def test_round_scale_down_to_power_of_2_valid_inputs(
test_case: tuple,
):
test_case_name, input, expected_result = test_case
input_tensor, expected_tensor = (
torch.tensor(input, dtype=torch.float32).cuda(),
torch.tensor(expected_result, dtype=torch.float32).cuda(),
)
result = _round_scale_down_to_power_of_2(input_tensor)

assert (
torch.equal(result, expected_tensor)
or (result.isnan() and expected_tensor.isnan())
), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"


@pytest.mark.parametrize(
"invalid_dtype",
[
torch.bfloat16,
torch.float16,
torch.float64,
torch.int8,
torch.uint8,
torch.int32,
torch.uint32,
torch.int64,
],
)
def test_non_float32_input(invalid_dtype: torch.dtype):
non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
with pytest.raises(AssertionError, match="scale must be float32 tensor"):
_round_scale_down_to_power_of_2(non_float32_tensor)
21 changes: 18 additions & 3 deletions torchao/float8/config.py
@@ -234,6 +234,13 @@ class Float8LinearConfig:
# tests so that the warning does not spam the CI stdout.
force_recompute_fp8_weight_in_bwd: bool = False

# If this option is enabled, the scaling factor used for float8 quantization
# will be rounded down to the nearest power of 2. This has been shown to help
# reduce quantization error by avoiding rounding errors when multiplying/dividing
# by the scaling factor, as well as ensuring large values are quantized to the
# same value in the forward pass as in the backward pass.
round_scales_to_power_of_2: bool = False

def __post_init__(self):
# Populate the additional cast overrides, if the user did not specify them
# Note: this hacks around the frozen-ness of this dataclass
@@ -328,14 +335,22 @@ def recipe_name_to_linear_config(

elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
# dynamic axiswise scaling with the CUTLASS rowwise kernel
cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_i = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_w = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
)

return Float8LinearConfig(
cast_config_input=cc_i,
cast_config_weight=cc_w,
cast_config_grad_output=cc_go,
# enable power of 2 scaling factors by default for row-wise scaling
round_scales_to_power_of_2=True,
)

elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
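A short sketch of the recipe change above: the rowwise (ALL_AXISWISE) recipe now opts into power of 2 scale rounding by default, while a manually constructed Float8LinearConfig keeps it off unless requested. The import path follows the file shown in this diff (torchao/float8/config.py) and is otherwise an assumption:

from torchao.float8.config import (
    Float8LinearConfig,
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# Rowwise recipe: power of 2 scale rounding is enabled by default.
rowwise = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert rowwise.round_scales_to_power_of_2 is True

# Hand-built config: the new field defaults to False.
manual = Float8LinearConfig()
assert manual.round_scales_to_power_of_2 is False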
6 changes: 6 additions & 0 deletions torchao/float8/float8_linear.py
@@ -96,6 +96,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_input.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@ def forward(
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_weight.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

# the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_grad_output.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
-1, c.cast_config_weight_for_grad_input.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
axiswise_dim=get_maybe_axiswise_dim(
0, c.cast_config_input_for_grad_weight.scaling_granularity
),
round_scales_to_power_of_2=c.round_scales_to_power_of_2,
)

grad_weight = torch.mm(
4 changes: 4 additions & 0 deletions torchao/float8/float8_scaling_utils.py
@@ -27,6 +27,7 @@
)


# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly
def hp_tensor_to_float8_dynamic(
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
@@ -36,6 +37,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
round_scales_to_power_of_2: bool = False,
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor`,
@@ -51,6 +53,7 @@ def hp_tensor_to_float8_dynamic(
the 3 fwd/bwd gemms of linear
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
"""
scale = tensor_to_scale(
hp_tensor,
@@ -59,6 +62,7 @@ def hp_tensor_to_float8_dynamic(
device_mesh,
scaling_granularity,
axiswise_dim,
round_scales_to_power_of_2,
)
return hp_tensor_and_scale_to_float8(
hp_tensor,
44 changes: 33 additions & 11 deletions torchao/float8/float8_utils.py
@@ -10,11 +10,7 @@
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce

from torchao.float8.config import (
Float8LinearConfig,
ScalingGranularity,
ScalingType,
)
from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType

# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -33,21 +29,28 @@


@torch.no_grad()
def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
def amax_to_scale(
amax: torch.Tensor,
float8_dtype: torch.dtype,
round_scales_to_power_of_2: bool = False,
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
"""
# torch.compile and eager show different numerics for 1.0 / float32,
# upcast to float64 to ensure same numeric between compile and eager
amax = amax.to(torch.float64)
if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
res = res.to(torch.float32)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

return res.to(torch.float32)
if round_scales_to_power_of_2:

Inline review thread on this hunk:

Reviewer (Contributor):
if we're using bit shifting, IMO it would be good to

  1. wrap this into a function
  2. assert the input is float32
  3. add tests just around this function, testing that 0, positive finite number, infinity, nan are all handled correctly

it's ok as is, but numerical correctness is IMO a good place to be super explicit and eliminate potential confusion

danielvegamyhre (Contributor, Author) commented on Feb 6, 2025:
Makes sense, done. I also did another round of torchtitan benchmarks with the final implementation:

Float8 row wise without power of 2:

[rank0]:2025-02-06 12:13:42,245 - root - INFO - step:  1  loss: 12.2341  memory: 47.97GiB(50.49%)  tps: 755  mfu: 4.42%
[rank0]:2025-02-06 12:13:42,246 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 12:13:54,758 - root - INFO - step: 10  loss: 10.0339  memory: 62.87GiB(66.17%)  tps: 5,893  mfu: 34.51%
[rank0]:2025-02-06 12:14:08,326 - root - INFO - step: 20  loss:  8.4962  memory: 62.87GiB(66.17%)  tps: 6,038  mfu: 35.36%
[rank0]:2025-02-06 12:14:21,886 - root - INFO - step: 30  loss:  7.6160  memory: 62.87GiB(66.17%)  tps: 6,042  mfu: 35.38%

Float8 row wise with power of 2:

[rank0]:2025-02-06 12:10:54,300 - root - INFO - step:  1  loss: 12.2512  memory: 47.97GiB(50.49%)  tps: 347  mfu: 2.03%
[rank0]:2025-02-06 12:10:54,301 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 12:11:06,505 - root - INFO - step: 10  loss: 10.1018  memory: 62.87GiB(66.17%)  tps: 6,041  mfu: 35.38%
[rank0]:2025-02-06 12:11:20,063 - root - INFO - step: 20  loss:  8.6927  memory: 62.87GiB(66.17%)  tps: 6,043  mfu: 35.39%
[rank0]:2025-02-06 12:11:33,621 - root - INFO - step: 30  loss:  7.6843  memory: 62.87GiB(66.17%)  tps: 6,042  mfu: 35.38%

danielvegamyhre (Contributor, Author):
this is ready for another look when you have time - CI error on H100s is unrelated:

docker: Error response from daemon: failed to create task for container: failed to create shim task: OCI runtime create failed: runc create failed: unable to start container process: error during container init: error running prestart hook #0: exit status 1, stdout: , stderr: Auto-detected mode as 'legacy'
nvidia-container-cli: error parsing IMEX info: unsupported IMEX channel value: all: unknown.

I think it may be caused by using a legacy container image without certain IMEX env var set? NVIDIA/nvidia-container-toolkit#797

anyway, i'll try retriggering CI, and in the meantime i'll take a look at the triton kernels compile generates for the exp2(floor(log2(x))) approach and see if i can tell why it's slow

danielvegamyhre (Contributor, Author) commented on Feb 6, 2025:
I just looked into why exp2(floor(log2(x))) is slow, and it's actually an easy fix: when we do the rounding on this line, the scale is still in fp64, which for some reason causes it to be slow.

If we convert to fp32 before doing the rounding, instead of at the end when we return (return res.to(torch.float32)), the TPS regression is eliminated.

Maybe it's simply that these rounding ops are slower at double the bit width (FP64 peaks at half the FP32 max TFLOPs on H100)? I'm surprised the effect is pronounced enough to cause an 8% regression in overall TPS when rounding fp64 scales, though. I haven't looked into the generated Triton kernels yet, prioritizing shipping this first.

Benchmark data:

When scale is still float64 when rounding:

[rank0]:2025-02-06 13:24:25,812 - root - INFO - step:  1  loss: 12.2439  memory: 47.97GiB(50.49%)  tps: 863  mfu: 5.06%
[rank0]:2025-02-06 13:24:25,812 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 13:24:38,548 - root - INFO - step: 10  loss:  9.9485  memory: 62.87GiB(66.17%)  tps: 5,789  mfu: 33.90%
[rank0]:2025-02-06 13:24:52,685 - root - INFO - step: 20  loss:  8.4416  memory: 62.87GiB(66.17%)  tps: 5,795  mfu: 33.93%
[rank0]:2025-02-06 13:25:06,827 - root - INFO - step: 30  loss:  7.6019  memory: 62.87GiB(66.17%)  tps: 5,793  mfu: 33.92%
[rank0]:2025-02-06 13:25:20,968 - root - INFO - step: 40  loss:  7.4452  memory: 62.87GiB(66.17%)  tps: 5,793  mfu: 33.93%

When scale is converted to fp32 before rounding:

[rank0]:2025-02-06 13:22:44,780 - root - INFO - step:  1  loss: 12.2436  memory: 47.97GiB(50.49%)  tps: 859  mfu: 5.03%
[rank0]:2025-02-06 13:22:44,781 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-02-06 13:22:56,675 - root - INFO - step: 10  loss:  9.9731  memory: 62.87GiB(66.17%)  tps: 6,199  mfu: 36.30%
[rank0]:2025-02-06 13:23:09,869 - root - INFO - step: 20  loss:  8.5158  memory: 62.87GiB(66.17%)  tps: 6,209  mfu: 36.36%
[rank0]:2025-02-06 13:23:23,060 - root - INFO - step: 30  loss:  7.5902  memory: 62.87GiB(66.17%)  tps: 6,211  mfu: 36.37%
[rank0]:2025-02-06 13:23:36,270 - root - INFO - step: 40  loss:  7.3799  memory: 62.87GiB(66.17%)  tps: 6,202  mfu: 36.32%

res = _round_scale_down_to_power_of_2(res)
return res


@torch.no_grad()
Expand Down Expand Up @@ -119,21 +122,35 @@ def tensor_to_amax(

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor,
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
device_mesh=None,
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
axiswise_dim: Optional[int] = None,
round_scales_to_power_of_2: bool = False,
) -> torch.Tensor:
"""
Compute scaling factor for the given high precision tensor.

Args:
hp_tensor: high precision tensor
float8_dtype: the float8 dtype to use
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
"""
amax = tensor_to_amax(
x,
hp_tensor,
reduce_amax,
device_mesh,
scaling_granularity,
axiswise_dim,
)
return amax_to_scale(amax, float8_dtype)
return amax_to_scale(
amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2
)


def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
@@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
)


def _round_scale_down_to_power_of_2(scale: torch.Tensor):
assert scale.dtype == torch.float32, "scale must be float32 tensor"
return torch.exp2(torch.floor(torch.log2(scale)))
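
To close the loop on the review thread above, here is a small self-contained sketch (not part of the PR) that mirrors _round_scale_down_to_power_of_2 and the fp32-before-rounding ordering discussed there, plus a check of the exactness property that motivates power of 2 scales. The example values are made up for illustration:

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Mirrors the helper above: exp2(floor(log2(x))) rounds each positive
    # scale down to the nearest power of 2 (i.e. zeroes the mantissa bits).
    assert scale.dtype == torch.float32, "scale must be float32 tensor"
    return torch.exp2(torch.floor(torch.log2(scale)))

# Cast the scale to float32 *before* rounding, as discussed in the thread;
# rounding while still in fp64 was measured to cost roughly 8% TPS in torchtitan.
amax = torch.tensor([0.75, 3.0, 1000.0], dtype=torch.float64)
scale = (torch.finfo(torch.float8_e4m3fn).max / amax).to(torch.float32)
print(round_scale_down_to_power_of_2(scale))  # values: 512.0, 128.0, 0.25

# Why powers of 2 help: multiplying and dividing by a power of 2 only shifts
# the exponent, so the scaling itself introduces no rounding error.
x = torch.randn(8, dtype=torch.float32)
s = torch.tensor(2.0**-7, dtype=torch.float32)
assert torch.equal((x * s) / s, x)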