Neuron SDK Release 2.21.1 (#1089)

natemail-aws authored Jan 15, 2025
1 parent d853f41 commit ae20816
Showing 25 changed files with 519 additions and 212 deletions.
13 changes: 11 additions & 2 deletions dlami/index.rst
@@ -17,10 +17,19 @@ to easily get started on single Neuron instance. Below sections describe the sup

Neuron Multi Framework DLAMI
----------------------------
Neuron Deep Learning AMI (DLAMI) is a multi-framework DLAMI that supports multiple Neuron frameworks/libraries. Each DLAMI is pre-installed with Neuron drivers and supports all Neuron instance types. Each virtual environment that corresponds to a specific Neuron framework/library
comes pre-installed with all the Neuron libraries, including the Neuron compiler and Neuron runtime, needed for you to easily get started.


.. note::

   Tensorflow-neuron 2.10 (inf1), released in SDK v2.20.2, is not compatible with the latest runtime in the v2.21 SDK.
   Code that compiles successfully will encounter runtime errors with the latest SDK 2.21.1 version.

   The Neuron team is aware of this issue, and it will be fixed in the next minor release.

   Please refer to `this page <https://github.com/aws-neuron/aws-neuron-sdk/issues/1071>`_ for more information on the issue and a temporary workaround.

Multi Framework DLAMIs supported
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -218,83 +218,82 @@ compiled and executed if there are extra mark-steps or functions with
implicit mark-steps. Additionally, more graphs can be generated if there
are different execution paths taken due to control-flows.

Automatic casting of float tensors to BFloat16
----------------------------------------------

With PyTorch Neuron, the default behavior is for torch.float (FP32) and torch.double (FP64) tensors
to be mapped to torch.float in hardware. To reduce memory footprint and improve performance,
torch.float and torch.double tensors can automatically be converted to BFloat16 by setting
the environment variable ``XLA_USE_BF16=1``. Alternatively, torch.float can automatically be converted
to BFloat16 and torch.double converted to FP32 by setting the environment variable ``XLA_DOWNCAST_BF16=1``.
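
For illustration, a minimal sketch of setting one of these variables from Python (the variables themselves are described above; setting them before any XLA device work is an assumption of this sketch):

.. code:: python

    import os

    # Map torch.float and torch.double tensors to BF16
    os.environ["XLA_USE_BF16"] = "1"

    # Alternatively, map torch.float to BF16 and torch.double to FP32
    # os.environ["XLA_DOWNCAST_BF16"] = "1"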

Automatic Mixed-Precision
-------------------------

BF16 mixed-precision using PyTorch Autocast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

By default, the compiler automatically casts internal FP32 operations to
BF16. You can disable this and allow PyTorch's BF16 mixed-precision to
do the casting. PyTorch's BF16 mixed-precision is achieved by casting
certain operations to operate in BF16. We currently use CUDA's list of
operations that can operate in BF16:

.. code:: bash

    _convolution
    _convolution
    _convolution_nogroup
    conv1d
    conv2d
    conv3d
    conv_tbc
    conv_transpose1d
    conv_transpose2d
    conv_transpose3d
    convolution
    cudnn_convolution
    cudnn_convolution_transpose
    cudnn_convolution
    cudnn_convolution_transpose
    cudnn_convolution
    cudnn_convolution_transpose
    prelu
    addmm
    addmv
    addr
    matmul
    mm
    mv
    linear
    addbmm
    baddbmm
    bmm
    chain_matmul
    linalg_multi_dot

Full BF16 with stochastic rounding enabled
------------------------------------------

Previously, on torch-neuronx 2.1 and earlier, the environment variables ``XLA_USE_BF16`` or ``XLA_DOWNCAST_BF16`` provided full casting to BF16 with stochastic rounding enabled by default. These environment variables are deprecated in torch-neuronx 2.5, although still functional with warnings. To replace ``XLA_USE_BF16`` or ``XLA_DOWNCAST_BF16`` with stochastic rounding on Neuron, set ``NEURON_RT_STOCHASTIC_ROUNDING_EN=1`` and use the ``torch.nn.Module.to`` method to cast model floating-point parameters and buffers to data-type BF16 as follows:

.. code:: python

    os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"
    os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

    # model is created
    model.to(torch.bfloat16)

Stochastic rounding is needed to enable faster convergence for a full BF16 model.

If the loss is to be kept in FP32, initialize it with ``dtype=torch.float`` as follows:

.. code:: python

    running_loss = torch.zeros(1, dtype=torch.float).to(device)

Similarly, if the optimizer states are to be kept in FP32, convert the gradients to FP32 before optimizer computations:

.. code:: python

    grad = p.grad.data.float()

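For context, a minimal sketch of where such a conversion could sit inside a custom optimizer step (the loop and names below are illustrative, not the tutorial's optimizer code):

.. code:: python

    # Illustrative sketch: keep optimizer math in FP32 while the model runs in BF16.
    for group in optimizer.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data.float()  # upcast the BF16 gradient to FP32
            # ... use `grad` in the FP32 optimizer state update ...
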
For a full example, please see the :ref:`PyTorch Neuron BERT Pretraining Tutorial (Data-Parallel) <hf-bert-pretraining-tutorial>`, which has been updated to use ``torch.nn.Module.to`` instead of ``XLA_DOWNCAST_BF16``.

BF16 in GPU-compatible mode without stochastic rounding enabled
---------------------------------------------------------------

Full BF16 training in GPU-compatible mode would enable faster convergence without the need for stochastic rounding, but would require an FP32 copy of weights/parameters to be saved and used in the optimizer. To enable BF16 in GPU-compatible mode without stochastic rounding enabled, use the ``torch.nn.Module.to`` method to cast model floating-point parameters and buffers to data-type bfloat16 as follows, without setting ``NEURON_RT_STOCHASTIC_ROUNDING_EN=1``:

.. code:: python

    # model is created
    model.to(torch.bfloat16)

In the initializer of the optimizer, for example AdamW, you can add code like the following snippet to make an FP32 copy of the weights:

.. code:: python

    # keep a copy of weights in high precision (FP32)
    self.param_groups_highprec = []
    for group in self.param_groups:
        params = group['params']
        param_groups_highprec = [p.data.float() for p in params]
        self.param_groups_highprec.append({'params': param_groups_highprec})

In the :ref:`PyTorch Neuron BERT Pretraining Tutorial (Data-Parallel) <hf-bert-pretraining-tutorial>`, this mode can be enabled by passing the ``--optimizer=AdamW_FP32ParamsCopy`` option to ``dp_bert_large_hf_pretrain_hdf5.py`` and setting ``NEURON_RT_STOCHASTIC_ROUNDING_EN=0`` (or leaving it unset).

.. _automatic_mixed_precision_autocast:

BF16 automatic mixed precision using PyTorch Autocast
-----------------------------------------------------

By default, the compiler automatically casts internal FP32 operations to
BF16. You can disable this and allow PyTorch's BF16 automatic mixed precision function (``torch.autocast``) to
do the casting of certain operations to operate in BF16.

To enable PyTorch's BF16 mixed-precision, first turn off the Neuron
compiler auto-cast:

.. code:: python

    os.environ["NEURON_CC_FLAGS"] = "--auto-cast=none"

Next, per recommendation from official PyTorch `torch.autocast documentation <https://pytorch.org/docs/stable/amp.html#autocasting>`__, place only
the forward-pass of the training step in the ``torch.autocast`` scope with ``xla`` device type:

.. code:: python

    with torch.autocast(dtype=torch.bfloat16, device_type='xla'):
        # forward pass

The device type is XLA because we are using PyTorch-XLA's autocast backend. The PyTorch-XLA `autocast mode source code <https://github.com/pytorch/xla/blob/master/torch_xla/csrc/autocast_mode.cpp>`_ lists which operations are cast to lower precision BF16 ("lower precision fp cast policy" section), which are maintained in FP32 ("fp32 cast policy"), and which are promoted to the widest input types ("promote" section).
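
For a quick, illustrative way to observe the policy (a sketch assuming a Neuron/XLA device is available; the exact op coverage is defined by the source file above):

.. code:: python

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    a = torch.randn(4, 4, device=device)
    b = torch.randn(4, 4, device=device)

    with torch.autocast(dtype=torch.bfloat16, device_type='xla'):
        mm = torch.matmul(a, b)  # matmul follows the lower-precision policy
        c = a + b                # ops outside the policy keep their input dtype

    print(mm.dtype, c.dtype)  # expected: torch.bfloat16 torch.float32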

Example showing the original training code snippet:

Expand All @@ -319,7 +318,7 @@ The following shows the training loop modified to use BF16 autocast:
def train_loop_fn(train_loader):
    for i, data in enumerate(train_loader):
        torch.cuda.is_bf16_supported = lambda: True
        with torch.autocast(dtype=torch.bfloat16, device_type='xla'):
            inputs = data[0]
            labels = data[3]
            outputs = model(inputs, labels=labels)

@@ -328,7 +327,7 @@ The following shows the training loop modified to use BF16 autocast:
        optimizer.step()
        xm.mark_step()

For a full example of BF16 mixed-precision, see :ref:`PyTorch Neuron BERT Pretraining Tutorial (Data-Parallel) <hf-bert-pretraining-tutorial>`.

See official PyTorch documentation for more details about
`torch.autocast <https://pytorch.org/docs/stable/amp.html#autocasting>`__
@@ -370,6 +369,12 @@ intermediate results such as loss values. In such case, the printing of
lazy tensors should be wrapped using ``xm.add_step_closure()`` to avoid
unnecessary compilation-and-executions.
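
For illustration, a minimal sketch of wrapping such a print in a step closure (the ``_print_loss`` helper and its arguments are illustrative, not from this guide):

.. code:: python

    import torch_xla.core.xla_model as xm

    def _print_loss(loss, step):
        # Runs after the step's graph has executed, so reading the tensor value
        # here does not trigger an extra compilation-and-execution.
        print(f"step {step} loss {loss.item()}")

    # Inside the training loop, instead of printing the lazy tensor directly:
    xm.add_step_closure(_print_loss, args=(loss, step))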

Aggregate the data transfers between host CPUs and devices
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For best performance, you may try to aggregate the data transfers between host CPUs and devices.
For example, increasing the value of the ``batches_per_execution`` argument when instantiating ``MpDeviceLoader`` can help increase performance for certain models with frequent host-device traffic, such as ViT, as described in `a blog <https://towardsdatascience.com/ai-model-optimization-on-aws-inferentia-and-trainium-cfd48e85d5ac>`_. NOTE: Increasing the ``batches_per_execution`` value delays the mark-step across the number of batches specified by this value, which increases the graph size and could lead to a device out-of-memory (OOM) error.
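
As a hedged sketch (the ``train_loader`` name and the value 4 are illustrative; ``MpDeviceLoader`` forwards extra keyword arguments such as ``batches_per_execution`` to the underlying ``ParallelLoader``):

.. code:: python

    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()
    # batches_per_execution=4 is an illustrative value; larger values reduce
    # host-device round trips but enlarge the graph and device memory usage.
    train_device_loader = pl.MpDeviceLoader(train_loader, device, batches_per_execution=4)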

Ensure common initial weights across workers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -396,13 +401,6 @@ be loaded using ``serialization.load`` api. More information on this here: `Savi

FAQ
---

What is the difference between Trainium and Inferentia?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Trainium is an accelerator designed to speed up training, whereas
Inferentia is an accelerator designed to speed up inference.

Debugging and troubleshooting
-----------------------------

@@ -87,9 +87,9 @@ just like standard versions of ``torch-neuronx``.
Building ``torch`` and ``torch-xla`` with C++11 ABI
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The instructions for building ``torch`` from source are at https://github.com/pytorch/pytorch#from-source

The instructions for building ``torch-xla`` from source are at https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md

The following are simplified instructions (subject to change):

@@ -184,8 +184,8 @@ package file directly and ``unzip`` the wheel:
.. _pytorch-neuronx-cxx11-versioning:

How can I know which ABI torch-neuronx is using?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Packages which use the pre-C++11 ABI have no local identifier and use the
following version scheme:
10 changes: 5 additions & 5 deletions frameworks/torch/torch-neuronx/tutorials/training/bert.rst
@@ -1,10 +1,10 @@
.. _hf-bert-pretraining-tutorial:

Hugging Face BERT Pretraining Tutorial (Data-Parallel)
======================================================

This tutorial explains how to run Hugging Face BERT-Large model
pretraining on Trainium using PyTorch Neuron and data-parallel mode.

The Hugging Face BERT pretraining example demonstrates the steps
required to perform single-node, multi-accelerator PyTorch model
@@ -44,7 +44,7 @@ Phase 1 BFloat16 BERT-Large pretraining with AdamW and stochastic rounding
Setting up the training environment on trn1.32xlarge
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The BERT training script ``dp_bert_large_hf_pretrain_hdf5.py`` (`source <https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/dp_bert_large_hf_pretrain_hdf5.py>`_)
can run on a Trainium instance (trn1.32xlarge) that contains the
appropriate Neuron runtime and Python dependencies.

@@ -60,7 +60,7 @@ For all the commands below, make sure you are in the virtual environment that yo
source ~/aws_neuron_venv_pytorch/bin/activate
Next, clone the `AWS Neuron Samples repository <https://github.com/aws-neuron/aws-neuron-samples/>`_ and install requirements in the BERT tutorial directory ``aws-neuron-samples/torch-neuronx/training/dp_bert_hf_pretrain`` (`directory link <https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx/training/dp_bert_hf_pretrain>`_):

.. code:: shell
4 changes: 2 additions & 2 deletions general/appnotes/torch-neuronx/introducing-pytorch-2-x.rst
@@ -63,7 +63,7 @@ To migrate the training scripts from PyTorch NeuronX 2.1 to PyTorch NeuronX 2.5,

``xm`` below refers to ``torch_xla.core.xla_model`` and ``xr`` refers to ``torch_xla.runtime``

* The environment variables ``XLA_DOWNCAST_BF16`` and ``XLA_USE_BF16`` are deprecated (warning when used). Please switch to automatic mixed-precision or use ``model.to(torch.bfloat16)`` to convert the model to BF16 format. (see :ref:`migration_from_xla_downcast_bf16`)
* The ``torch_xla.experimental.pjrt`` module which was replaced by ``torch_xla.runtime`` in Torch-XLA 2.1, has been removed in Torch-XLA 2.5. Users should now utilize the ``torch_xla.runtime`` module as a replacement.
* ``torch_xla.runtime.using_pjrt`` is removed because PJRT is the sole Torch-XLA runtime.
* ``xm.all_reduce`` no longer operates in-place for single tensors. To fix this, please convert the single tensor to an array (e.g., ``[single_tensor]``) or assign the output of ``xm.all_reduce`` to a variable, as shown in the sketch below.
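
  A minimal sketch of the adjustment (``single_tensor`` is an illustrative name):

  .. code:: python

      import torch_xla.core.xla_model as xm

      # torch-xla 2.5: passing a list reduces the tensors in place ...
      xm.all_reduce(xm.REDUCE_SUM, [single_tensor])

      # ... or capture the returned value when passing a single tensor
      single_tensor = xm.all_reduce(xm.REDUCE_SUM, single_tensor)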
@@ -108,7 +108,7 @@ This is a warning that ``torch_xla.core.xla_model.xrt_world_size()`` will be rem
WARNING:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This is a warning that ``torch_xla.core.xla_model.xla_model.get_ordinal()`` will be removed in a future release. Please switch to using ``torch_xla.runtime.global_ordinal`` instead.


AttributeError: module 'torch_xla.runtime' has no attribute 'using_pjrt'