Support mixed MX element dtype in mx_mm function and MXLinear. (#1667)

* Support mixed MX element dtype in `mx_mm` function.

Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients.

* Support (input, weight, gradient) element dtype tuple in MXLinear layer factory method.

Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface
of `MXLinear` and `swap_linear_with_mx_linear`.

Additional unit test coverage has been added for `MXLinear`.

* Use the default `elem_dtype` argument with optional weight/grad overrides (see the usage sketch below).
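
For illustration, a minimal sketch of the resulting interface after this change. The dtype choices are placeholders: using `torch.float8_e5m2` for the gradient is an assumption for demonstration (it must be in `SUPPORTED_ELEM_DTYPES`), and a CUDA device is assumed.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(32, 32)).cuda()

# The first elem_dtype covers all three tensors by default; the optional
# overrides give the weight and grad_output their own MX element dtypes.
swap_linear_with_mx_linear(
    m,
    torch.float8_e4m3fn,                                # input/activation
    elem_dtype_weight_override=torch.float8_e4m3fn,     # weight
    elem_dtype_grad_output_override=torch.float8_e5m2,  # grad_output (assumed supported)
    block_size=32,
)

# training loop (not shown)
```

Both overrides default to `None`, in which case the weight and grad_output reuse `elem_dtype`, matching the previous single-dtype behavior.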
balancap authored Feb 6, 2025
1 parent 867a91f commit 1d75c8f
Showing 3 changed files with 88 additions and 26 deletions.
32 changes: 26 additions & 6 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import copy
import itertools

import pytest
import torch
@@ -41,13 +42,16 @@ def run_around_tests():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
@pytest.mark.parametrize(
"elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
def test_linear_eager(elem_dtype, bias, input_shape):
"""
Smoke test for training linear module with mx weight
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
grad_shape[-1] = 6

@@ -56,7 +60,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
)
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
x = copy.deepcopy(x_ref)
@@ -72,7 +76,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad)
x_g_sqnr = compute_error(x_ref.grad, x.grad)

if elem_dtype is torch.float8_e4m3fn:
if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn):
assert y_sqnr >= 18.0
assert w_g_sqnr >= 18.0
assert x_g_sqnr >= 12.0
@@ -94,7 +98,7 @@ def test_activation_checkpointing():
nn.Linear(6, 6, bias=True, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m, elem_dtype, block_size)
swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
g = torch.randn(*grad_shape, device="cuda")
@@ -130,7 +134,7 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
nn.Linear(K, N, bias=bias, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

@@ -219,6 +223,20 @@ def test_inference_compile_simple(elem_dtype):
assert sqnr >= 13.5


def test_mx_linear_input_weight_gradient_dtypes():
m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]

m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
assert m[0].in_elem_dtype == torch.float8_e4m3fn
assert m[0].w_elem_dtype == torch.float8_e4m3fn
assert m[0].grad_elem_dtype == torch.float8_e4m3fn


def test_filter_fn():
m1 = nn.Sequential(
nn.Linear(32, 32),
@@ -227,7 +245,9 @@ def test_filter_fn():
m2 = copy.deepcopy(m1)
filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731

swap_linear_with_mx_linear(m1, torch.float8_e4m3fn, 32, filter_fn)
swap_linear_with_mx_linear(
m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
)
assert type(m1[0]) == MXLinear
assert type(m1[1]) == torch.nn.Linear

9 changes: 4 additions & 5 deletions torchao/prototype/mx_formats/README.md
@@ -2,8 +2,8 @@

This is a POC of training and inference with tensors in the MX format from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch.

Note that the current version of the code is written for readability and
numerical correctness and not yet for optimal performance. We welcome
Note that the current version of the code is written for readability and
numerical correctness and not yet for optimal performance. We welcome
contributions on performance improvements.

Note that there are no BC guarantees at the moment and we plan to evolve
@@ -44,8 +44,7 @@ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
elem_dtype = torch.float8_e4m3fn
block_size = 32
swap_linear_with_mx_linear(m, elem_dtype, block_size)
swap_linear_with_mx_linear(m, elem_dtype, block_size=32)

# training loop (not shown)
```
@@ -93,7 +92,7 @@ python torchao/prototype/mx_formats/benchmarks/bench_qdq.py

## floating point format convenience functions

We have a convenience script which summarizes the various properties of
We have a convenience script which summarizes the various properties of
floating point formats:

```bash
73 changes: 58 additions & 15 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# input, weight and grad_output can each have their own MX element dtype.

@staticmethod
def forward(
ctx,
input_hp: torch.Tensor,
weight_hp: torch.Tensor,
elem_dtype: Any,
in_elem_dtype: Any,
w_elem_dtype: Any,
grad_elem_dtype: Any,
block_size: int,
):
ctx.save_for_backward(input_hp, weight_hp)
ctx.elem_dtype = elem_dtype
ctx.in_elem_dtype = in_elem_dtype
ctx.w_elem_dtype = w_elem_dtype
ctx.grad_elem_dtype = grad_elem_dtype
ctx.block_size = block_size

# input @ weight_t = output
input_orig_shape = input_hp.shape
input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])

input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
output = output.reshape(*input_orig_shape[:-1], output.shape[-1])

@@ -51,7 +57,9 @@ def forward(
def backward(ctx, grad_output_hp: torch.Tensor):
input_hp, weight_hp = ctx.saved_tensors
weight_hp_t_c = weight_hp.t().contiguous()
elem_dtype = ctx.elem_dtype
in_elem_dtype = ctx.in_elem_dtype
w_elem_dtype = ctx.w_elem_dtype
grad_elem_dtype = ctx.grad_elem_dtype
block_size = ctx.block_size

grad_output_orig_shape = grad_output_hp.shape
@@ -61,24 +69,26 @@ def backward(ctx, grad_output_hp: torch.Tensor):
input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])

# grad_output @ weight = grad_input
grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
grad_output_mx_dim0 = MXTensor.to_mx(
grad_output_hp_r, grad_elem_dtype, block_size
)
weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)

# input_t @ grad_output = grad_weight
grad_output_mx_dim1 = MXTensor.to_mx(
grad_output_hp_r.t().contiguous(), elem_dtype, block_size
grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
)
input_t_mx_dim0_tmp = MXTensor.to_mx(
input_hp_r.t().contiguous(), elem_dtype, block_size
input_hp_r.t().contiguous(), in_elem_dtype, block_size
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

return grad_input, grad_weight, None, None
return grad_input, grad_weight, None, None, None, None


class MXLinear(torch.nn.Linear):
@@ -87,13 +97,25 @@ class MXLinear(torch.nn.Linear):
matmul is emulated since there is no hardware support yet. Activations,
weights and grads are cast to MX and back to high precision for each
matmul.
Input, weight and grad_output can each have their own MX element dtype.
"""

@classmethod
@torch.no_grad()
def from_float(cls, mod, elem_dtype, block_size):
def from_float(
cls,
mod,
elem_dtype,
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
block_size=32,
):
mod.__class__ = MXLinear
mod.elem_dtype = elem_dtype
mod.in_elem_dtype = elem_dtype
mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
mod.block_size = block_size
return mod

@@ -106,7 +128,14 @@ def forward(self, x):
else:
w = self.weight

y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
y = mx_mm.apply(
x,
w,
self.in_elem_dtype,
self.w_elem_dtype,
self.grad_elem_dtype,
self.block_size,
)
if self.bias is not None:
y = y + self.bias
return y
@@ -172,7 +201,15 @@ def _is_linear(mod, fqn):
return isinstance(mod, torch.nn.Linear)


def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None):
def swap_linear_with_mx_linear(
model,
elem_dtype,
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
block_size=32,
filter_fn=None,
):
if filter_fn is None:
combined_filter_fn = _is_linear
else:
@@ -183,7 +220,13 @@ def __fn(mod, fqn):
combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXLinear.from_float(mod, elem_dtype, block_size),
lambda mod: MXLinear.from_float(
mod,
elem_dtype,
elem_dtype_weight_override,
elem_dtype_grad_output_override,
block_size=block_size,
),
combined_filter_fn,
)
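
To make the new per-tensor dtypes concrete, here is a minimal sketch of calling the updated `MXLinear.from_float` factory directly. The `torch.float8_e5m2` gradient dtype is an illustrative assumption (it must be in `SUPPORTED_ELEM_DTYPES`), and a CUDA device is assumed; nothing here is prescribed by the diff beyond the signature itself.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import MXLinear

# Convert a single Linear layer in place; the weight reuses the input dtype
# and grad_output gets its own override (illustrative choice).
lin = nn.Linear(32, 32, device="cuda")
mx_lin = MXLinear.from_float(
    lin,
    torch.float8_e4m3fn,                                 # input/activation elem dtype
    elem_dtype_weight_override=torch.float8_e4m3fn,      # weight elem dtype
    elem_dtype_grad_output_override=torch.float8_e5m2,   # grad_output elem dtype (assumed supported)
    block_size=32,
)
assert mx_lin.in_elem_dtype == torch.float8_e4m3fn
assert mx_lin.grad_elem_dtype == torch.float8_e5m2
```

Per the diff, `from_float` swaps the module class in place and records the three dtypes as `in_elem_dtype`, `w_elem_dtype`, and `grad_elem_dtype`, which `mx_mm.apply` then consumes in the forward and backward passes.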

