[release/2.3] [ROCm] Correct numerical issues in layer norm backwards kernel (pytorch#140259) (#1766)

It was reported that the backwards layer norm on AMD was slightly less
accurate than the equivalent NVIDIA implementation.

On AMD we call into a helper kernel, `cuLoadWriteStridedInputs`, which
processes strided inputs and accumulates the partial gradients into
shared memory.
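
For context, below is a rough, simplified sketch of that pattern. It is not the actual PyTorch kernel: the function and buffer names (`load_write_strided`, `dgamma_part`, `dbeta_part`), the launch layout, and the use of global atomics in place of the real shared-memory warp buffers are all illustrative assumptions.

// Simplified sketch of a strided loader for layer-norm backward partials.
// NOT the real cuLoadWriteStridedInputs; names and layout are illustrative.
#include <cuda_fp16.h>

// T is the storage type (e.g. __half); T_ACC is the accumulation type (e.g. float).
template <typename T, typename T_ACC>
__global__ void load_write_strided(
    const T* dout, const T* input,
    const T_ACC* mean, const T_ACC* rstd,
    T_ACC* dgamma_part, T_ACC* dbeta_part,  // per-column partial sums
    int M, int N) {
  // Each block strides over rows; each thread strides over columns.
  for (int i1 = blockIdx.x; i1 < M; i1 += gridDim.x) {
    T_ACC curr_mean = mean[i1];  // keep the per-row statistics in T_ACC
    T_ACC curr_rstd = rstd[i1];
    for (int i2 = threadIdx.x; i2 < N; i2 += blockDim.x) {
      T_ACC dy = static_cast<T_ACC>(dout[i1 * N + i2]);
      T_ACC x  = static_cast<T_ACC>(input[i1 * N + i2]);
      // dbeta accumulates dy; dgamma accumulates dy times the normalized input.
      // The real kernel stages these in shared-memory warp buffers before a
      // block-level reduction; global atomics keep this sketch short.
      atomicAdd(&dbeta_part[i2], dy);
      atomicAdd(&dgamma_part[i2], dy * (x - curr_mean) * curr_rstd);
    }
  }
}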

In this kernel (pytorch#87635) `mean` and `rstd` were truncated from the
accumulator type `T_ACC` to `T`, which causes numerical issues in the warp
buffers built by this kernel. This PR uses the correct accumulator type for
`mean` and `rstd`.
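
To see why the truncation matters, here is a minimal host-side sketch of the rounding step the patch removes, assuming `T` is `__half` and `T_ACC` is `float`; the values are made up. The rounded statistic then multiplies every element's contribution in the warp buffers, so the error propagates into the accumulated gradients.

// Minimal sketch: rounding fp32 statistics through __half before use.
// Values are illustrative; this is not code from the PR.
#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>

int main() {
  // Per-row statistics as the accumulator type (T_ACC = float) holds them.
  float mean_acc = 0.123456789f;
  float rstd_acc = 12.3456789f;   // 1/sqrt(var + eps)

  // Buggy path: narrow to the storage type T (= __half) before use.
  float mean_trunc = __half2float(__float2half(mean_acc));
  float rstd_trunc = __half2float(__float2half(rstd_acc));

  std::printf("mean: fp32=%.9f  via half=%.9f  rel err=%.3e\n",
              mean_acc, mean_trunc, std::fabs(mean_acc - mean_trunc) / mean_acc);
  std::printf("rstd: fp32=%.9f  via half=%.9f  rel err=%.3e\n",
              rstd_acc, rstd_trunc, std::fabs(rstd_acc - rstd_trunc) / rstd_acc);
  return 0;
}

With these values the relative error is on the order of 1e-4, roughly the granularity of fp16 rounding.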

Note: only AMD calls into this code path for the backwards layer norm, so
this was not an issue for NVIDIA.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
jataylo authored Dec 6, 2024
1 parent 7870ca9 commit a7b07f9
Showing 1 changed file with 2 additions and 2 deletions.
aten/src/ATen/native/cuda/layer_norm_kernel.cu

@@ -840,8 +840,8 @@ void cuLoadWriteStridedInputs(
 {
   int i1 = i1_block+thr_load_row_off;
   if (i1 < i1_end) {
-    T curr_mean = mean[i1];
-    T curr_rstd = rstd[i1];
+    T_ACC curr_mean = mean[i1];
+    T_ACC curr_rstd = rstd[i1];
     for (int k = 0; k < blockDim.y; ++k) {
       int i2 = i2_off + k;
       int load_idx = i1*N+i2;
