fix linter issues
danielvegamyhre committed Feb 5, 2025
1 parent f2433b1 commit a9fe17e
Showing 2 changed files with 14 additions and 12 deletions.
12 changes: 6 additions & 6 deletions test/float8/test_base.py
@@ -15,9 +15,9 @@
 import torch.nn as nn

 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
-    TORCH_VERSION_AT_LEAST_2_5,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -26,13 +26,13 @@

 from torchao.float8.config import (
     CastConfig,
-    e4m3_dtype,
-    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    recipe_name_to_linear_config,
     ScalingGranularity,
     ScalingType,
+    e4m3_dtype,
+    e5m2_dtype,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -48,15 +48,15 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
-    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
+    hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
+    FP8_TYPES,
     compute_error,
     config_has_stateful_scaling,
     fp8_tensor_statistics,
-    FP8_TYPES,
     tensor_to_scale,
 )
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
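The reordering above matches isort-style "order-by-type" sorting: ALL_CAPS constants first, CamelCase classes next, lowercase names last, alphabetical within each group. The linter configuration itself is not part of this diff, so the snippet below is only a minimal sketch that reproduces the resulting order; the key function name is hypothetical.

# Minimal sketch (assumed rule, not torchao's actual lint tooling):
# reproduce the "order-by-type" member ordering seen in the hunks above.
def isort_order_by_type_key(name: str) -> tuple[int, str]:
    if name.isupper():
        group = 0  # constants, e.g. TORCH_VERSION_AT_LEAST_2_5, FP8_TYPES
    elif name[:1].isupper():
        group = 1  # classes, e.g. Float8LinearConfig, ScalingType
    else:
        group = 2  # functions/objects, e.g. e4m3_dtype, recipe_name_to_linear_config
    return (group, name.lower())

names = [
    "CastConfig", "e4m3_dtype", "e5m2_dtype", "Float8LinearConfig",
    "Float8LinearRecipeName", "recipe_name_to_linear_config",
    "ScalingGranularity", "ScalingType",
]
print(sorted(names, key=isort_order_by_type_key))
# ['CastConfig', 'Float8LinearConfig', 'Float8LinearRecipeName',
#  'ScalingGranularity', 'ScalingType', 'e4m3_dtype', 'e5m2_dtype',
#  'recipe_name_to_linear_config']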
14 changes: 8 additions & 6 deletions test/float8/test_compile.py
@@ -13,9 +13,9 @@
 import pytest

 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
-    TORCH_VERSION_AT_LEAST_2_5,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
-    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    recipe_name_to_linear_config,
     ScalingType,
+    e4m3_dtype,
+    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -430,6 +430,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
     hp_tensor2 = hp_tensor1.detach().clone()
     float8_config = Float8LinearConfig(
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        power_of_2_scale=power_of_2_scale,
     )
     linear_mm_config = LinearMMConfig(
         # output
@@ -459,15 +460,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
-        power_of_2_scale=power_of_2_scale,
+        power_of_2_scale=float8_config.power_of_2_scale,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
-        power_of_2_scale=power_of_2_scale,
+        power_of_2_scale=float8_config.power_of_2_scale,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)
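For context on the power_of_2_scale flag exercised above: the test now threads it through Float8LinearConfig and checks that eager and torch.compile produce bit-identical scales and data. The rounding itself is not shown in this diff; the following is only a minimal sketch of the usual idea (round each scale down to the nearest power of two so only exponent bits change), not necessarily torchao's exact implementation, and the helper name is hypothetical.

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: snap each scale down to the nearest power of two,
    # so multiplying/dividing by it only shifts the exponent and leaves the
    # mantissa untouched (convenient for exact eager-vs-compile parity checks).
    return torch.exp2(torch.floor(torch.log2(scale)))

scales = torch.tensor([0.3, 1.0, 5.0])
print(round_scale_down_to_power_of_2(scales))  # tensor([0.2500, 1.0000, 4.0000])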
@@ -479,7 +480,8 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
-    from torch._inductor import config as inductor_config, metrics
+    from torch._inductor import config as inductor_config
+    from torch._inductor import metrics

     inductor_config.loop_ordering_after_fusion = True
