Apply vectorization for batch_norm channels last kernel #1306

Closed
wants to merge 2 commits into from

Conversation

xytintel
Contributor

No description provided.

@xytintel
Contributor Author

Profiler results for BatchNorm2d in channels-last format with N, H, W, C = 4, 160, 256, 24 (bfloat16 input, 120 iterations):

Original

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::native_batch_norm        32.46%      36.835ms        55.59%      63.086ms     525.716us       9.036ms        94.62%       9.284ms      77.364us           120  
at::native::xpu::BatchNormCollectStatisticsChannelsL...         0.00%       0.000us         0.00%       0.000us       0.000us       5.028ms        52.65%       5.028ms      41.899us           120  
at::native::xpu::BatchNormTransformInputChannelsLast...         0.00%       0.000us         0.00%       0.000us       0.000us       3.707ms        38.81%       3.707ms      30.889us           120  
at::native::xpu::UnrolledElementwiseForMultiOutputsK...         0.00%       0.000us         0.00%       0.000us       0.000us     301.760us         3.16%     301.760us       2.515us           120  
                                             aten::add_        39.55%      44.881ms        43.19%      49.021ms     408.507us     266.400us         2.79%     266.400us       2.220us           120  
at::native::xpu::VectorizedElementwiseKernel<2, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     266.400us         2.79%     266.400us       2.220us           120  
                                            aten::fill_        10.76%      12.215ms        13.91%      15.782ms     131.518us     247.360us         2.59%     247.360us       2.061us           120  
at::native::xpu::VectorizedElementwiseKernel<4, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     247.360us         2.59%     247.360us       2.061us           120  
                                  urEnqueueKernelLaunch        14.10%      16.002ms        14.10%      16.002ms      26.670us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::batch_norm         0.32%     358.626us        56.81%      64.473ms     537.271us       0.000us         0.00%       9.284ms      77.364us           120  
                           aten::_batch_norm_impl_index         0.59%     672.126us        56.49%      64.114ms     534.282us       0.000us         0.00%       9.284ms      77.364us           120  
                                            aten::empty         0.77%     877.312us         0.77%     877.312us       1.462us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::empty_like         0.26%     291.281us         0.90%       1.016ms       8.469us       0.000us         0.00%       0.000us       0.000us           120  
                                    aten::empty_strided         0.48%     543.179us         0.64%     725.000us       6.042us       0.000us         0.00%       0.000us       0.000us           120  
                                            aten::zeros         0.31%     347.857us        14.58%      16.543ms     137.856us       0.000us         0.00%     247.360us       2.061us           120  
                                            aten::zero_         0.25%     287.867us        14.16%      16.070ms     133.917us       0.000us         0.00%     247.360us       2.061us           120  
                                       urUSMDeviceAlloc         0.16%     181.821us         0.16%     181.821us     181.821us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 113.493ms
Self XPU time total: 9.550ms

Optimized

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                aten::native_batch_norm        31.62%      28.617ms        56.36%      51.009ms     425.076us       8.527ms        94.54%       8.760ms      72.996us           120  
at::native::xpu::BatchNormCollectStatisticsChannelsL...         0.00%       0.000us         0.00%       0.000us       0.000us       4.938ms        54.75%       4.938ms      41.152us           120  
at::native::xpu::BatchNormTransformInputChannelsLast...         0.00%       0.000us         0.00%       0.000us       0.000us       3.291ms        36.49%       3.291ms      27.427us           120  
at::native::xpu::UnrolledElementwiseForMultiOutputsK...         0.00%       0.000us         0.00%       0.000us       0.000us     297.920us         3.30%     297.920us       2.483us           120  
                                             aten::add_        38.60%      34.936ms        42.60%      38.563ms     321.355us     260.480us         2.89%     260.480us       2.171us           120  
at::native::xpu::VectorizedElementwiseKernel<2, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     260.480us         2.89%     260.480us       2.171us           120  
                                            aten::fill_        11.32%      10.248ms        14.73%      13.333ms     111.107us     232.160us         2.57%     232.160us       1.935us           120  
at::native::xpu::VectorizedElementwiseKernel<4, at::...         0.00%       0.000us         0.00%       0.000us       0.000us     232.160us         2.57%     232.160us       1.935us           120  
                                  urEnqueueKernelLaunch        15.67%      14.180ms        15.67%      14.180ms      23.634us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::batch_norm         0.25%     230.599us        57.40%      51.951ms     432.922us       0.000us         0.00%       8.760ms      72.996us           120  
                           aten::_batch_norm_impl_index         0.52%     469.456us        57.14%      51.720ms     431.001us       0.000us         0.00%       8.760ms      72.996us           120  
                                            aten::empty         0.79%     711.994us         0.79%     711.994us       1.187us       0.000us         0.00%       0.000us       0.000us           600  
                                       aten::empty_like         0.22%     199.079us         0.74%     667.652us       5.564us       0.000us         0.00%       0.000us       0.000us           120  
                                    aten::empty_strided         0.41%     370.126us         0.52%     468.573us       3.905us       0.000us         0.00%       0.000us       0.000us           120  
                                            aten::zeros         0.30%     270.881us        15.36%      13.899ms     115.826us       0.000us         0.00%     232.160us       1.935us           120  
                                            aten::zero_         0.20%     181.815us        14.93%      13.515ms     112.622us       0.000us         0.00%     232.160us       1.935us           120  
                                       urUSMDeviceAlloc         0.11%      98.447us         0.11%      98.447us      98.447us       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 90.513ms
Self XPU time total: 9.020ms
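
To put the two tables side by side, here is a minimal sketch that computes the relative reduction in device time. The numbers are copied from the tables above; the kernel names stay truncated exactly as the profiler printed them, and the helper `percent_reduction` is a hypothetical name introduced only for this illustration.

```python
# Self XPU times in milliseconds, copied from the "Original" and
# "Optimized" profiler tables above as (before, after) pairs.
timings_ms = {
    "Self XPU time total": (9.550, 9.020),
    "BatchNormCollectStatisticsChannelsL...": (5.028, 4.938),
    "BatchNormTransformInputChannelsLast...": (3.707, 3.291),
}


def percent_reduction(before: float, after: float) -> float:
    """Relative reduction of `after` with respect to `before`, in percent."""
    return (before - after) / before * 100.0


for name, (before, after) in timings_ms.items():
    change = percent_reduction(before, after)
    print(f"{name}: {before:.3f} ms -> {after:.3f} ms ({change:.1f}% lower)")
```

On these figures, total device time drops by roughly 5.5%. Most of that comes from the transform-input kernel (about 11.2% lower), while the statistics-collection kernel changes by about 1.8%. This is a single run of 120 iterations, so small differences may be within measurement noise.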

@xytintel
Contributor Author

import torch
import torch.nn as nn

# Shape is given as NCHW; the tensor is physically laid out as channels-last (NHWC).
N, C, H, W = 4, 24, 160, 256
x = torch.randn(N, C, H, W).to(memory_format=torch.channels_last).bfloat16().xpu()
bn = nn.BatchNorm2d(C).xpu()

# Record both host-side and XPU device activity.
prof_xpu = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.XPU,
    ],
)
with prof_xpu:
    for _ in range(120):
        output = bn(x)

# Sort by device time so the most expensive XPU kernels appear first.
print(prof_xpu.key_averages(group_by_input_shape=True).table(sort_by="self_xpu_time_total", row_limit=100000))
print(output.dtype)
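
The PR carries no description of the kernel change, so the following is only a sketch of why a channels-last layout lends itself to vectorized access, using the benchmark shape above. It needs no PyTorch. The helpers `channels_last_strides` and `flat_offset` are hypothetical names, and the vector widths tried at the end are illustrative assumptions, not values taken from the kernel.

```python
def channels_last_strides(n: int, c: int, h: int, w: int) -> tuple:
    """Element strides, in (N, C, H, W) index order, of a tensor stored as NHWC."""
    return (h * w * c, 1, w * c, c)


def flat_offset(index: tuple, strides: tuple) -> int:
    """Linear memory offset of a 4-D index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


N, C, H, W = 4, 24, 160, 256
strides = channels_last_strides(N, C, H, W)
print(strides)  # (983040, 1, 6144, 24)

# Neighbouring channels of the same pixel are adjacent in memory, so a
# contiguous run of channels can be read with one wide load.
first = flat_offset((0, 0, 0, 0), strides)
second = flat_offset((0, 1, 0, 0), strides)
assert second - first == 1

# Assumed candidate vector widths (illustrative only): a width that divides C
# evenly never straddles the boundary between two pixels.
for vec in (2, 4, 8, 16):
    print(f"vector width {vec}: divides C={C} evenly -> {C % vec == 0}")
```

The stride tuple matches what `x.stride()` reports for a channels-last tensor of this shape in PyTorch.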

@xytintel
Contributor Author

Moved to #1317.

@xytintel xytintel closed this Jan 24, 2025