From 1d4d5dc05b56dc4ddfeb7093d8f2848d636c9914 Mon Sep 17 00:00:00 2001 From: Roger Feng Date: Fri, 16 Dec 2022 14:44:43 +0800 Subject: [PATCH] Initial code base MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Below features are verified with GPT-2 3.6 billion parameters Large Language Model training example with 6 Intel GPUs: - DeepSpeed ZeRO stage 2 - Model parallel size is 1 3.6B parameters Large Language Model is composed of following components Intel® Extension for DeepSpeed* Megatron LM Models with 3.6B parameters configuration Microsoft® DeepSpeed with pull request to support Intel®GPU Signed-off-by: Roger Feng --- LICENSE | 21 + MANIFEST.in | 2 + README.md | 18 + examples/LICENSE | 21 + examples/Megatron-LM/LICENSE | 231 ++ examples/Megatron-LM/NOTICE.txt | 253 ++ examples/Megatron-LM/README.md | 123 + examples/Megatron-LM/arguments.py | 379 +++ examples/Megatron-LM/bf16/__init__.py | 16 + examples/Megatron-LM/bf16/bf16.py | 70 + examples/Megatron-LM/configure_data.py | 246 ++ examples/Megatron-LM/data_utils/__init__.py | 122 + examples/Megatron-LM/data_utils/corpora.py | 75 + examples/Megatron-LM/data_utils/datasets.py | 884 ++++++ examples/Megatron-LM/data_utils/file_utils.py | 253 ++ .../Megatron-LM/data_utils/lazy_loader.py | 196 ++ examples/Megatron-LM/data_utils/samplers.py | 139 + examples/Megatron-LM/data_utils/tf_dl.py | 121 + .../Megatron-LM/data_utils/tokenization.py | 890 ++++++ .../data_utils/tokenization_gpt2.py | 304 ++ examples/Megatron-LM/data_utils/wordpiece.py | 390 +++ examples/Megatron-LM/detokenizer.py | 60 + examples/Megatron-LM/docker/Dockerfile | 38 + examples/Megatron-LM/docker/README.md | 1 + examples/Megatron-LM/docker/requirements.txt | 10 + examples/Megatron-LM/evaluate_gpt2.py | 554 ++++ examples/Megatron-LM/fp16/__init__.py | 30 + examples/Megatron-LM/fp16/fp16.py | 634 +++++ examples/Megatron-LM/fp16/fp16util.py | 204 ++ examples/Megatron-LM/fp16/loss_scaler.py | 238 ++ examples/Megatron-LM/generate_samples.py | 281 ++ examples/Megatron-LM/gpt2_data_loader.py | 211 ++ examples/Megatron-LM/learning_rates.py | 76 + examples/Megatron-LM/model/__init__.py | 20 + examples/Megatron-LM/model/distributed.py | 111 + examples/Megatron-LM/model/gpt2_modeling.py | 125 + examples/Megatron-LM/model/model.py | 90 + examples/Megatron-LM/model/modeling.py | 1382 +++++++++ examples/Megatron-LM/mpu/__init__.py | 53 + examples/Megatron-LM/mpu/cross_entropy.py | 109 + examples/Megatron-LM/mpu/data.py | 117 + examples/Megatron-LM/mpu/grads.py | 75 + examples/Megatron-LM/mpu/initialize.py | 135 + examples/Megatron-LM/mpu/layers.py | 331 +++ examples/Megatron-LM/mpu/mappings.py | 141 + examples/Megatron-LM/mpu/random.py | 387 +++ examples/Megatron-LM/mpu/tests/__init__.py | 0 examples/Megatron-LM/mpu/tests/commons.py | 83 + .../mpu/tests/test_cross_entropy.py | 111 + examples/Megatron-LM/mpu/tests/test_data.py | 92 + .../Megatron-LM/mpu/tests/test_initialize.py | 98 + examples/Megatron-LM/mpu/tests/test_layers.py | 529 ++++ examples/Megatron-LM/mpu/tests/test_random.py | 208 ++ examples/Megatron-LM/mpu/transformer.py | 650 +++++ examples/Megatron-LM/mpu/utils.py | 70 + examples/Megatron-LM/openwebtext/README.md | 46 + .../Megatron-LM/openwebtext/blacklist_urls.py | 312 ++ .../openwebtext/cleanup_dataset.py | 115 + .../openwebtext/find_duplicates.py | 100 + .../openwebtext/group_duplicates_url.py | 90 + .../openwebtext/make_gpt2_dataset.py | 77 + .../openwebtext/make_gpt2_sizes.py | 38 + .../Megatron-LM/openwebtext/merge_jsons.py | 55 + .../openwebtext/remove_group_duplicates.py | 69 + .../openwebtext/run_make_gpt2_dataset.sh | 8 + examples/Megatron-LM/openwebtext/tokenizer.py | 36 + examples/Megatron-LM/pretrain_bert.py | 586 ++++ examples/Megatron-LM/pretrain_gpt2.py | 771 +++++ examples/Megatron-LM/requirements.txt | 8 + .../scripts/ds_checkpoint_check.sh | 51 + .../scripts/ds_zero-offload_10B_config.json | 33 + ...ffload_10B_pretrain_gpt2_model_parallel.sh | 49 + .../scripts/ds_zero-offload_config.json | 31 + .../scripts/ds_zero-offload_config_bf16.json | 44 + ...ro-offload_pretrain_gpt2_model_parallel.sh | 48 + ...fload_pretrain_gpt2_model_parallel_bf16.sh | 48 + .../Megatron-LM/scripts/ds_zero2_config.json | 29 + .../scripts/ds_zero2_config_bf16.json | 43 + .../ds_zero2_pretrain_gpt2_model_parallel.sh | 48 + examples/Megatron-LM/scripts/generate_text.sh | 29 + examples/Megatron-LM/scripts/gpt-3.6b-fp16.sh | 49 + .../Megatron-LM/scripts/gpt-3.6b-offload.sh | 49 + examples/Megatron-LM/scripts/gpt-3.6b.sh | 61 + examples/Megatron-LM/scripts/mp2_256m.json | 41 + examples/Megatron-LM/scripts/mp2_256m.sh | 48 + .../scripts/presplit_sentences_json.py | 27 + examples/Megatron-LM/scripts/pretrain_bert.sh | 34 + .../scripts/pretrain_bert_distributed.sh | 43 + .../scripts/pretrain_bert_model_parallel.sh | 44 + .../scripts/pretrain_bert_sentencepiece.sh | 35 + .../pretrain_bert_tfrecords_distributed.sh | 44 + examples/Megatron-LM/scripts/pretrain_gpt2.sh | 34 + .../scripts/pretrain_gpt2_distributed.sh | 42 + .../scripts/pretrain_gpt2_model_parallel.sh | 42 + examples/Megatron-LM/scripts/run_gpt2_eval.py | 89 + examples/Megatron-LM/scripts/split_json.py | 119 + examples/Megatron-LM/utils.py | 411 +++ examples/README.md | 8 + examples/autotuning/.gitignore | 4 + examples/autotuning/README.md | 3 + examples/autotuning/hf/README.md | 62 + examples/autotuning/hf/bert-base/README.md | 58 + .../hf/bert-base/ds_config_tune.json | 12 + examples/autotuning/hf/bert-base/test_tune.sh | 114 + examples/autotuning/hf/bert-large/README.md | 55 + .../hf/bert-large/ds_config_tune.json | 11 + .../autotuning/hf/bert-large/test_tune.sh | 114 + examples/autotuning/hf/deberta/README.md | 72 + .../hf/deberta/ds_config_fp16_tune.json | 16 + examples/autotuning/hf/deberta/test_tune.sh | 127 + examples/autotuning/hf/distilbert/README.md | 69 + .../hf/distilbert/ds_config_tune.json | 12 + .../autotuning/hf/distilbert/test_tune.sh | 119 + .../hf/dsconfigs/ds_config_fp16_tune.json | 15 + .../hf/dsconfigs/ds_config_fp16_z0.json | 9 + .../hf/dsconfigs/ds_config_fp16_z1.json | 9 + .../hf/dsconfigs/ds_config_fp16_z2.json | 9 + .../hf/dsconfigs/ds_config_fp16_z3.json | 9 + .../hf/dsconfigs/ds_config_tune.json | 12 + .../autotuning/hf/dsconfigs/ds_config_z0.json | 6 + .../autotuning/hf/dsconfigs/ds_config_z1.json | 6 + .../autotuning/hf/dsconfigs/ds_config_z2.json | 6 + .../autotuning/hf/dsconfigs/ds_config_z3.json | 6 + examples/autotuning/hf/gpt2-large/README.md | 59 + .../autotuning/hf/gpt2-large/test_tune.sh | 132 + examples/autotuning/hf/gpt2-medium/README.md | 57 + .../autotuning/hf/gpt2-medium/test_tune.sh | 142 + examples/autotuning/hf/gpt2-xl/README.md | 56 + examples/autotuning/hf/gpt2-xl/test_tune.sh | 142 + examples/autotuning/hf/gpt2/README.md | 59 + examples/autotuning/hf/gpt2/test_tune.sh | 133 + examples/pipeline_parallelism/alexnet.py | 47 + examples/pipeline_parallelism/ds_config.json | 19 + examples/pipeline_parallelism/run.sh | 3 + examples/pipeline_parallelism/train.py | 159 ++ intel_extension_for_deepspeed/__init__.py | 1 + .../op_builder/__init__.py | 7 + .../op_builder/builder.py | 61 + .../op_builder/cpu_adagrad.py | 21 + .../op_builder/cpu_adam.py | 31 + .../op_builder/csrc/adam/sycl/cpu_adam.dp.cpp | 707 +++++ .../csrc/adam/sycl/custom_sycl_kernel.dp.cpp | 26 + .../csrc/adam/sycl/fused_adam_frontend.cpp | 20 + .../csrc/adam/sycl/multi_tensor_adam.dp.cpp | 215 ++ .../csrc/includes/multi_tensor_apply.dp.hpp | 174 ++ .../op_builder/csrc/includes/sycl/Timer.hpp | 41 + .../op_builder/csrc/includes/sycl/common.hpp | 30 + .../op_builder/csrc/includes/sycl/context.hpp | 146 + .../csrc/includes/sycl/cpu_adam.hpp | 164 ++ .../csrc/includes/sycl/custom_sycl_layers.hpp | 283 ++ .../op_builder/csrc/includes/sycl/dropout.hpp | 76 + .../includes/sycl/ds_transformer_sycl.hpp | 184 ++ .../csrc/includes/sycl/feed_forward.hpp | 125 + .../op_builder/csrc/includes/sycl/gelu.hpp | 35 + .../csrc/includes/sycl/gemm_test.hpp | 297 ++ .../csrc/includes/sycl/general_kernels.hpp | 43 + .../csrc/includes/sycl/normalize_layer.hpp | 197 ++ .../csrc/includes/sycl/onednn_wrappers.hpp | 31 + .../csrc/includes/sycl/onemkl_wrappers.hpp | 65 + .../op_builder/csrc/includes/sycl/softmax.hpp | 57 + .../csrc/includes/sycl/strided_batch_gemm.hpp | 268 ++ .../csrc/includes/sycl/type_shim.hpp | 112 + .../transformer/sycl/dropout_kernels.dp.cpp | 1194 ++++++++ .../transformer/sycl/ds_dropout_sycl.dp.cpp | 92 + .../sycl/ds_feedforward_sycl.dp.cpp | 81 + .../csrc/transformer/sycl/ds_gelu_sycl.dp.cpp | 39 + .../sycl/ds_layer_reorder_sycl.dp.cpp | 122 + .../transformer/sycl/ds_normalize_sycl.dp.cpp | 151 + .../transformer/sycl/ds_softmax_sycl.dp.cpp | 44 + .../sycl/ds_stridedbatchgemm_sycl.dp.cpp | 95 + .../sycl/ds_transformer_sycl.dp.cpp | 1081 +++++++ .../csrc/transformer/sycl/gelu_kernels.dp.cpp | 447 +++ .../transformer/sycl/general_kernels.dp.cpp | 540 ++++ .../transformer/sycl/normalize_kernels.dp.cpp | 2529 +++++++++++++++++ .../transformer/sycl/onednn_wrappers.dp.cpp | 124 + .../transformer/sycl/onemkl_wrappers.dp.cpp | 144 + .../transformer/sycl/softmax_kernels.dp.cpp | 854 ++++++ .../transformer/sycl/transform_kernels.dp.cpp | 891 ++++++ .../op_builder/fused_adam.py | 31 + .../op_builder/quantizer.py | 19 + .../op_builder/transformer.py | 47 + .../op_builder/utils.py | 24 + .../xpu_accelerator.py | 190 ++ setup.py | 37 + 184 files changed, 30128 insertions(+) create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 README.md create mode 100644 examples/LICENSE create mode 100644 examples/Megatron-LM/LICENSE create mode 100644 examples/Megatron-LM/NOTICE.txt create mode 100644 examples/Megatron-LM/README.md create mode 100755 examples/Megatron-LM/arguments.py create mode 100644 examples/Megatron-LM/bf16/__init__.py create mode 100755 examples/Megatron-LM/bf16/bf16.py create mode 100644 examples/Megatron-LM/configure_data.py create mode 100644 examples/Megatron-LM/data_utils/__init__.py create mode 100755 examples/Megatron-LM/data_utils/corpora.py create mode 100644 examples/Megatron-LM/data_utils/datasets.py create mode 100755 examples/Megatron-LM/data_utils/file_utils.py create mode 100644 examples/Megatron-LM/data_utils/lazy_loader.py create mode 100644 examples/Megatron-LM/data_utils/samplers.py create mode 100755 examples/Megatron-LM/data_utils/tf_dl.py create mode 100755 examples/Megatron-LM/data_utils/tokenization.py create mode 100644 examples/Megatron-LM/data_utils/tokenization_gpt2.py create mode 100755 examples/Megatron-LM/data_utils/wordpiece.py create mode 100755 examples/Megatron-LM/detokenizer.py create mode 100644 examples/Megatron-LM/docker/Dockerfile create mode 100644 examples/Megatron-LM/docker/README.md create mode 100644 examples/Megatron-LM/docker/requirements.txt create mode 100755 examples/Megatron-LM/evaluate_gpt2.py create mode 100644 examples/Megatron-LM/fp16/__init__.py create mode 100755 examples/Megatron-LM/fp16/fp16.py create mode 100644 examples/Megatron-LM/fp16/fp16util.py create mode 100755 examples/Megatron-LM/fp16/loss_scaler.py create mode 100755 examples/Megatron-LM/generate_samples.py create mode 100644 examples/Megatron-LM/gpt2_data_loader.py create mode 100644 examples/Megatron-LM/learning_rates.py create mode 100755 examples/Megatron-LM/model/__init__.py create mode 100755 examples/Megatron-LM/model/distributed.py create mode 100644 examples/Megatron-LM/model/gpt2_modeling.py create mode 100755 examples/Megatron-LM/model/model.py create mode 100644 examples/Megatron-LM/model/modeling.py create mode 100755 examples/Megatron-LM/mpu/__init__.py create mode 100644 examples/Megatron-LM/mpu/cross_entropy.py create mode 100644 examples/Megatron-LM/mpu/data.py create mode 100644 examples/Megatron-LM/mpu/grads.py create mode 100644 examples/Megatron-LM/mpu/initialize.py create mode 100644 examples/Megatron-LM/mpu/layers.py create mode 100644 examples/Megatron-LM/mpu/mappings.py create mode 100755 examples/Megatron-LM/mpu/random.py create mode 100644 examples/Megatron-LM/mpu/tests/__init__.py create mode 100644 examples/Megatron-LM/mpu/tests/commons.py create mode 100644 examples/Megatron-LM/mpu/tests/test_cross_entropy.py create mode 100644 examples/Megatron-LM/mpu/tests/test_data.py create mode 100644 examples/Megatron-LM/mpu/tests/test_initialize.py create mode 100644 examples/Megatron-LM/mpu/tests/test_layers.py create mode 100644 examples/Megatron-LM/mpu/tests/test_random.py create mode 100755 examples/Megatron-LM/mpu/transformer.py create mode 100644 examples/Megatron-LM/mpu/utils.py create mode 100644 examples/Megatron-LM/openwebtext/README.md create mode 100644 examples/Megatron-LM/openwebtext/blacklist_urls.py create mode 100644 examples/Megatron-LM/openwebtext/cleanup_dataset.py create mode 100644 examples/Megatron-LM/openwebtext/find_duplicates.py create mode 100644 examples/Megatron-LM/openwebtext/group_duplicates_url.py create mode 100644 examples/Megatron-LM/openwebtext/make_gpt2_dataset.py create mode 100644 examples/Megatron-LM/openwebtext/make_gpt2_sizes.py create mode 100644 examples/Megatron-LM/openwebtext/merge_jsons.py create mode 100644 examples/Megatron-LM/openwebtext/remove_group_duplicates.py create mode 100755 examples/Megatron-LM/openwebtext/run_make_gpt2_dataset.sh create mode 100644 examples/Megatron-LM/openwebtext/tokenizer.py create mode 100755 examples/Megatron-LM/pretrain_bert.py create mode 100755 examples/Megatron-LM/pretrain_gpt2.py create mode 100644 examples/Megatron-LM/requirements.txt create mode 100644 examples/Megatron-LM/scripts/ds_checkpoint_check.sh create mode 100755 examples/Megatron-LM/scripts/ds_zero-offload_10B_config.json create mode 100755 examples/Megatron-LM/scripts/ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh create mode 100755 examples/Megatron-LM/scripts/ds_zero-offload_config.json create mode 100644 examples/Megatron-LM/scripts/ds_zero-offload_config_bf16.json create mode 100755 examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel.sh create mode 100644 examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel_bf16.sh create mode 100755 examples/Megatron-LM/scripts/ds_zero2_config.json create mode 100755 examples/Megatron-LM/scripts/ds_zero2_config_bf16.json create mode 100755 examples/Megatron-LM/scripts/ds_zero2_pretrain_gpt2_model_parallel.sh create mode 100755 examples/Megatron-LM/scripts/generate_text.sh create mode 100755 examples/Megatron-LM/scripts/gpt-3.6b-fp16.sh create mode 100644 examples/Megatron-LM/scripts/gpt-3.6b-offload.sh create mode 100755 examples/Megatron-LM/scripts/gpt-3.6b.sh create mode 100755 examples/Megatron-LM/scripts/mp2_256m.json create mode 100755 examples/Megatron-LM/scripts/mp2_256m.sh create mode 100644 examples/Megatron-LM/scripts/presplit_sentences_json.py create mode 100755 examples/Megatron-LM/scripts/pretrain_bert.sh create mode 100755 examples/Megatron-LM/scripts/pretrain_bert_distributed.sh create mode 100644 examples/Megatron-LM/scripts/pretrain_bert_model_parallel.sh create mode 100755 examples/Megatron-LM/scripts/pretrain_bert_sentencepiece.sh create mode 100755 examples/Megatron-LM/scripts/pretrain_bert_tfrecords_distributed.sh create mode 100644 examples/Megatron-LM/scripts/pretrain_gpt2.sh create mode 100755 examples/Megatron-LM/scripts/pretrain_gpt2_distributed.sh create mode 100644 examples/Megatron-LM/scripts/pretrain_gpt2_model_parallel.sh create mode 100644 examples/Megatron-LM/scripts/run_gpt2_eval.py create mode 100644 examples/Megatron-LM/scripts/split_json.py create mode 100644 examples/Megatron-LM/utils.py create mode 100644 examples/README.md create mode 100644 examples/autotuning/.gitignore create mode 100644 examples/autotuning/README.md create mode 100644 examples/autotuning/hf/README.md create mode 100644 examples/autotuning/hf/bert-base/README.md create mode 100644 examples/autotuning/hf/bert-base/ds_config_tune.json create mode 100755 examples/autotuning/hf/bert-base/test_tune.sh create mode 100644 examples/autotuning/hf/bert-large/README.md create mode 100644 examples/autotuning/hf/bert-large/ds_config_tune.json create mode 100755 examples/autotuning/hf/bert-large/test_tune.sh create mode 100644 examples/autotuning/hf/deberta/README.md create mode 100644 examples/autotuning/hf/deberta/ds_config_fp16_tune.json create mode 100755 examples/autotuning/hf/deberta/test_tune.sh create mode 100644 examples/autotuning/hf/distilbert/README.md create mode 100644 examples/autotuning/hf/distilbert/ds_config_tune.json create mode 100755 examples/autotuning/hf/distilbert/test_tune.sh create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_fp16_tune.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_fp16_z0.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_fp16_z1.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_fp16_z2.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_fp16_z3.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_tune.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_z0.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_z1.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_z2.json create mode 100644 examples/autotuning/hf/dsconfigs/ds_config_z3.json create mode 100644 examples/autotuning/hf/gpt2-large/README.md create mode 100755 examples/autotuning/hf/gpt2-large/test_tune.sh create mode 100644 examples/autotuning/hf/gpt2-medium/README.md create mode 100755 examples/autotuning/hf/gpt2-medium/test_tune.sh create mode 100644 examples/autotuning/hf/gpt2-xl/README.md create mode 100755 examples/autotuning/hf/gpt2-xl/test_tune.sh create mode 100644 examples/autotuning/hf/gpt2/README.md create mode 100755 examples/autotuning/hf/gpt2/test_tune.sh create mode 100644 examples/pipeline_parallelism/alexnet.py create mode 100644 examples/pipeline_parallelism/ds_config.json create mode 100755 examples/pipeline_parallelism/run.sh create mode 100755 examples/pipeline_parallelism/train.py create mode 100644 intel_extension_for_deepspeed/__init__.py create mode 100755 intel_extension_for_deepspeed/op_builder/__init__.py create mode 100644 intel_extension_for_deepspeed/op_builder/builder.py create mode 100644 intel_extension_for_deepspeed/op_builder/cpu_adagrad.py create mode 100644 intel_extension_for_deepspeed/op_builder/cpu_adam.py create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/cpu_adam.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/custom_sycl_kernel.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/fused_adam_frontend.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/multi_tensor_adam.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/multi_tensor_apply.dp.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/Timer.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/common.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/context.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/cpu_adam.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/custom_sycl_layers.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/dropout.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/ds_transformer_sycl.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/feed_forward.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gelu.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gemm_test.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/general_kernels.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/normalize_layer.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onednn_wrappers.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onemkl_wrappers.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/softmax.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/strided_batch_gemm.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/type_shim.hpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/dropout_kernels.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_dropout_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_feedforward_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_gelu_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_layer_reorder_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_normalize_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_softmax_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_stridedbatchgemm_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_transformer_sycl.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/gelu_kernels.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/general_kernels.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/normalize_kernels.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onednn_wrappers.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onemkl_wrappers.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/softmax_kernels.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/transform_kernels.dp.cpp create mode 100644 intel_extension_for_deepspeed/op_builder/fused_adam.py create mode 100644 intel_extension_for_deepspeed/op_builder/quantizer.py create mode 100644 intel_extension_for_deepspeed/op_builder/transformer.py create mode 100644 intel_extension_for_deepspeed/op_builder/utils.py create mode 100644 intel_extension_for_deepspeed/xpu_accelerator.py create mode 100644 setup.py diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..082ff4a --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intel Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..13ca22a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +recursive-include intel_extension_for_deepspeed/op_builder/csrc *.cpp *.hpp +recursive-include intel_extension_for_deepspeed *.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..8533735 --- /dev/null +++ b/README.md @@ -0,0 +1,18 @@ +# Intel Extension for DeepSpeed +Intel Extension for DeepSpeed is an extension that brings Intel GPU (XPU) support to DeepSpeed(https://github.com/Microsoft/DeepSpeed). It implements DeepSpeed Accelerator Interface as defined in https://github.com/microsoft/DeepSpeed/pull/2471. + +Intel Extension for DeepSpeed comes with the following components: +1. DeepSpeed Accelerator Interface implementation +2. DeepSpeed op builders implmentation for XPU +3. DeepSpeed op builder kernel code + +DeepSpeed would automatically use Intel Extension for DeepSpeed when it is installed as a python package. After installation, models ported for DeepSpeed Accelerator Interface that run on DeepSpeed as in https://github.com/microsoft/DeepSpeed/pull/2471 could run on Intel GPU device. + +Usage: +1. Install Intel Extension for DeepSpeed + +`python setup.py install` + +2. Install DeepSpeed + +`CC=dpcpp CFLAGS=-fPIC CXX=dpcpp CXXFLAGS=-fPIC DS_BUILD_DEVICE=dpcpp DS_BUILD_CPU_ADAM=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_QUANTIZER=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_UTILS=1 python setup.py install` diff --git a/examples/LICENSE b/examples/LICENSE new file mode 100644 index 0000000..3d8b93b --- /dev/null +++ b/examples/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/examples/Megatron-LM/LICENSE b/examples/Megatron-LM/LICENSE new file mode 100644 index 0000000..b84f5de --- /dev/null +++ b/examples/Megatron-LM/LICENSE @@ -0,0 +1,231 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +------------- LICENSE FOR huggingface(transformer) repository -------------- + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/Megatron-LM/NOTICE.txt b/examples/Megatron-LM/NOTICE.txt new file mode 100644 index 0000000..6271b03 --- /dev/null +++ b/examples/Megatron-LM/NOTICE.txt @@ -0,0 +1,253 @@ +NOTICES AND INFORMATION +Do Not Translate or Localize + +This software incorporates material from third parties. Microsoft makes certain +open source code available at https://3rdpartysource.microsoft.com, or you may +send a check or money order for US $5.00, including the product name, the open +source component name, and version number, to: + +Source Code Compliance Team +Microsoft Corporation +One Microsoft Way +Redmond, WA 98052 +USA + +Notwithstanding any other terms, you may reverse engineer this software to the +extent required to debug changes to any libraries licensed under the GNU Lesser +General Public License. + +Component: Megatron-LM + +Open Source License/Copyright Notice. + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +------------- LICENSE FOR huggingface(transformer) repository -------------- + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/Megatron-LM/README.md b/examples/Megatron-LM/README.md new file mode 100644 index 0000000..6a94b46 --- /dev/null +++ b/examples/Megatron-LM/README.md @@ -0,0 +1,123 @@ +Megatron is a large, powerful transformer. This repo is for ongoing research on training large, powerful transformer language models at scale. Currently, we support multicards training of [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). + +The codebase is capable of efficiently training a 30-layer, 3.6 Billion Parameter GPT2 Language Model across Intel GPUs. + + +# Setup +We officially support python3.9.7 and we highly recommend to install an [Anaconda](https://www.anaconda.com/distribution/#download-section) environment. + +## Prerequisites +To use this repo please install the specified versions of dependent software. You will need: +- Python 3.9.7 or later. +- Intel GPU driver for AI/compute workload + +## Install Dependencies +Install Framework Dependency: +- [Intel® Extension for PyTorch\*](https://github.com/intel/intel-extension-for-pytorch/tree/xpu-master) with XPU support. +- [Torch-ccl](https://github.com/intel/torch-ccl) with XPU support. + +Install DeepSpeed Dependency: +- [DeepSpeed with XPU support](https://github.com/microsoft/DeepSpeed/pull/2221) +- [Intel® Extension for DeepSpeed\*](https://github.com/intel/intel-extension-for-deepspeed) + + +Create a virtual environment by conda and install dependent python packages: +``` +conda create --name gpt2_env python=3.9.7 +conda activate gpt2_env +pip install -r requirements.txt +``` + +# Usage +We've provided a script, gpt-3.6b.sh, for pretrain GPT2. + +## GPT2 Pretraining +`bash scripts/gpt-3.6b.sh` + +This script launches gpt2 pretraining that is verified on Intel GPUs. + + +``` +python pretrain_gpt2.py \ + --model-parallel-size 1 \ + --num-layers 30 \ + --hidden-size 3072 \ + --num-attention-heads 32 \ + --batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 1024 \ + --train-iters 1000 \ + --resume-dataloader \ + --train-data c4/en \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend ccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --bf16 +``` + +## Datasets +We do not host any datasets for GPT2 training. However, we detail the collection so that our results can be reproduced. + +### Prepare c4/en Training Data +We use c4/en/3.0.1 dataset from [HuggingFace/AllenAI](https://huggingface.co/datasets/allenai/c4). First, make sure you have [Git Large File Storage](https://git-lfs.github.com/) installed. + +``` +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/allenai/c4 +cd c4 +git-lfs pull --include "en/*" +``` + +This will download c4/en about 305GB to your local device. Then you can run the following commands to get a train json named c4-train.json. +``` +mkdir -p data/c4/en/ +cat c4/en/c4-train* > data/c4/en/c4-train.json.gz +pushd data/c4/en +gzip -d c4-train.json.gz +popd +cat c4/en/c4-validation.0000* > data/c4/en/c4-validation.json.gz +``` +If you don't need all c4/en datasets, you can run the following commands to merge 1024 original json.gz files into 8 json.gz files. +``` +cd + +mkdir -p softlinks +for shard in {0..7}; do + start=$((shard * 128)) + end=$((shard * 128 + 127)) + mkdir -p softlinks/en_$shard + for ind in $(seq -f "%05g" $start $end); do + ln -s ../../en/c4-train.${ind}-of-01024.json.gz softlinks/en_${shard}/c4-train.${ind}-of-01024.json.gz + done +done +mkdir -p en_merge +for shard in {0..7}; do + cat softlinks/en_${shard}/*gz > en_merge/c4-train.en_${shard}.json.gz +done +``` + +If your system is memory limited we also recommend to run pretraining with the `--lazy-loader` argument as we've done. After preprocessing the dataset once, this will allow the dataset to be lazily loaded from disk, as opposed to storing it in memory. + + +### Aliasing datasets with corpora.py +We recommend aliasing datasets with human readable names (eg. `--train-data wikipedia`). This helps avoid forgetting arguments when submitting jobs, and allows one to combine datasets that would otherwise require different commandline options/data structures. + +Examples of how to create these dataset objects can be found in [`./data_utils/corpora.py`](./data_utils/corpora.py). We recommend that the objects inherit from or adhere to the interface laid out by `torch.utils.data.Dataset` objects. + +Any created datasets should be then added to the `NAMED_CORPORA` dictionary object in [`./data_utils/corpora.py`](./data_utils/corpora.py). At runtime one can specify one or more corpora from the commandline with `--train-data corpus1 corpus2 corpus3`, `--valid-data corpus1 corpus2 corpus3`, or `--test-data ...`. + + +### Partitioning datasets into Train/Val/Test +We support multiple ways to partition corpora into train/val/test splits. By specifying a `--split 95,5` commandline argument, the corpora specified by `--train-data` will have it's documents split proportionally into a 95%, 5% train/val split. The split is performed lazily on the fly and is efficient and deterministic from run to run given the same `--seed`. Note that if `--valid-data` or `--test-data` is specified then the train data will still be split accordingly, but `--valid-data`/`--test-data` will still be used as the validation/test source. + +We do realize that this method, while effective, introduces noise into the development process, since different seeds will change the dataset and outcome. To have fixed training/validation/test sets across all your runs please utilize our script [`./scripts/split_json.py`](./scripts/split_json.py) + diff --git a/examples/Megatron-LM/arguments.py b/examples/Megatron-LM/arguments.py new file mode 100755 index 0000000..834a04d --- /dev/null +++ b/examples/Megatron-LM/arguments.py @@ -0,0 +1,379 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""argparser configuration""" + +import argparse +import os +import torch +import deepspeed +from deepspeed.accelerator.real_accelerator import get_accelerator + + +def add_model_config_args(parser): + """Model arguments""" + + group = parser.add_argument_group('model', 'model configuration') + + group.add_argument('--pretrained-bert', action='store_true', + help='use a pretrained bert-large-uncased model instead' + 'of initializing from scratch. See ' + '--tokenizer-model-type to specify which pretrained ' + 'BERT model to use') + group.add_argument('--attention-dropout', type=float, default=0.1, + help='dropout probability for attention weights') + group.add_argument('--num-attention-heads', type=int, default=16, + help='num of transformer attention heads') + group.add_argument('--hidden-size', type=int, default=1024, + help='tansformer hidden size') + group.add_argument('--intermediate-size', type=int, default=None, + help='transformer embedding dimension for FFN' + 'set to 4*`--hidden-size` if it is None') + group.add_argument('--num-layers', type=int, default=24, + help='num decoder layers') + group.add_argument('--layernorm-epsilon', type=float, default=1e-5, + help='layer norm epsilon') + group.add_argument('--hidden-dropout', type=float, default=0.1, + help='dropout probability for hidden state transformer') + group.add_argument('--max-position-embeddings', type=int, default=512, + help='maximum number of position embeddings to use') + group.add_argument('--vocab-size', type=int, default=30522, + help='vocab size to use for non-character-level ' + 'tokenization. This value will only be used when ' + 'creating a tokenizer') + group.add_argument('--deep-init', action='store_true', + help='initialize bert model similar to gpt2 model.' + 'scales initialization of projection layers by a ' + 'factor of 1/sqrt(2N). Necessary to train bert ' + 'models larger than BERT-Large.') + group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, + help='Pad the vocab size to be divisible by this value.' + 'This is added for computational efficieny reasons.') + group.add_argument('--cpu-optimizer', action='store_true', + help='Run optimizer on CPU') + group.add_argument('--cpu_torch_adam', action='store_true', + help='Use Torch Adam as optimizer on CPU.') + + return parser + + +def add_fp16_config_args(parser): + """Mixed precision arguments.""" + + group = parser.add_argument_group('fp16', 'fp16 configurations') + + group.add_argument('--fp16', action='store_true', + help='Run model in fp16 mode') + group.add_argument('--fp32-embedding', action='store_true', + help='embedding in fp32') + group.add_argument('--fp32-layernorm', action='store_true', + help='layer norm in fp32') + group.add_argument('--fp32-tokentypes', action='store_true', + help='embedding token types in fp32') + group.add_argument('--fp32-allreduce', action='store_true', + help='all-reduce in fp32') + group.add_argument('--hysteresis', type=int, default=2, + help='hysteresis for dynamic loss scaling') + group.add_argument('--loss-scale', type=float, default=None, + help='Static loss scaling, positive power of 2 ' + 'values can improve fp16 convergence. If None, dynamic' + 'loss scaling is used.') + group.add_argument('--loss-scale-window', type=float, default=1000, + help='Window over which to raise/lower dynamic scale') + group.add_argument('--min-scale', type=float, default=1, + help='Minimum loss scale for dynamic loss scale') + + return parser + + +def add_bf16_config_args(parser): + """Mixed precision arguments.""" + + group = parser.add_argument_group('bf16', 'bf16 configurations') + + group.add_argument('--bf16', action='store_true', + help='Run model in bf16 mode') + + return parser + + +def add_training_args(parser): + """Training arguments.""" + + group = parser.add_argument_group('train', 'training configurations') + + group.add_argument('--batch-size', type=int, default=4, + help='Data Loader batch size') + group.add_argument('--weight-decay', type=float, default=0.01, + help='weight decay coefficient for L2 regularization') + group.add_argument('--checkpoint-activations', action='store_true', + help='checkpoint activation to allow for training ' + 'with larger models and sequences') + group.add_argument('--checkpoint-num-layers', type=int, default=1, + help='chunk size (number of layers) for checkpointing') + group.add_argument('--deepspeed-activation-checkpointing', action='store_true', + help='uses activation checkpointing from deepspeed') + group.add_argument('--clip-grad', type=float, default=1.0, + help='gradient clipping') + group.add_argument('--train-iters', type=int, default=1000000, + help='total number of iterations to train over all training runs') + group.add_argument('--log-interval', type=int, default=100, + help='report interval') + group.add_argument('--exit-interval', type=int, default=None, + help='Exit the program after this many new iterations.') + + group.add_argument('--seed', type=int, default=1234, + help='random seed') + # Batch prodecuer arguments + group.add_argument('--reset-position-ids', action='store_true', + help='Reset posistion ids after end-of-document token.') + group.add_argument('--reset-attention-mask', action='store_true', + help='Reset self attention maske after ' + 'end-of-document token.') + + # Learning rate. + group.add_argument('--lr-decay-iters', type=int, default=None, + help='number of iterations to decay LR over,' + ' If None defaults to `--train-iters`*`--epochs`') + group.add_argument('--lr-decay-style', type=str, default='linear', + choices=['constant', 'linear', 'cosine', 'exponential'], + help='learning rate decay function') + group.add_argument('--lr', type=float, default=1.0e-4, + help='initial learning rate') + group.add_argument('--warmup', type=float, default=0.01, + help='percentage of data to warmup on (.01 = 1% of all ' + 'training iters). Default 0.01') + # model checkpointing + group.add_argument('--save', type=str, default=None, + help='Output directory to save checkpoints to.') + group.add_argument('--save-interval', type=int, default=5000, + help='number of iterations between saves') + group.add_argument('--no-save-optim', action='store_true', + help='Do not save current optimizer.') + group.add_argument('--no-save-rng', action='store_true', + help='Do not save current rng state.') + group.add_argument('--load', type=str, default=None, + help='Path to a directory containing a model checkpoint.') + group.add_argument('--no-load-optim', action='store_true', + help='Do not load optimizer when loading checkpoint.') + group.add_argument('--no-load-rng', action='store_true', + help='Do not load rng state when loading checkpoint.') + group.add_argument('--finetune', action='store_true', + help='Load model for finetuning. Do not load optimizer ' + 'or rng state from checkpoint and set iteration to 0. ' + 'Assumed when loading a release checkpoint.') + group.add_argument('--resume-dataloader', action='store_true', + help='Resume the dataloader when resuming training. ' + 'Does not apply to tfrecords dataloader, try resuming' + 'with a different seed in this case.') + # distributed training args + group.add_argument('--distributed-backend', default='nccl', + help='which backend to use for distributed ' + 'training. One of [gloo, nccl]') + + group.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher') + group.add_argument('--disable-sysmon', action='store_true', + help='#check the GPU memory after each iteration') + + return parser + + +def add_evaluation_args(parser): + """Evaluation arguments.""" + + group = parser.add_argument_group('validation', 'validation configurations') + + group.add_argument('--eval-batch-size', type=int, default=None, + help='Data Loader batch size for evaluation datasets.' + 'Defaults to `--batch-size`') + group.add_argument('--eval-iters', type=int, default=100, + help='number of iterations to run for evaluation' + 'validation/test for') + group.add_argument('--eval-interval', type=int, default=1000, + help='interval between running evaluation on validation set') + group.add_argument('--eval-seq-length', type=int, default=None, + help='Maximum sequence length to process for ' + 'evaluation. Defaults to `--seq-length`') + group.add_argument('--eval-max-preds-per-seq', type=int, default=None, + help='Maximum number of predictions to use for ' + 'evaluation. Defaults to ' + 'math.ceil(`--eval-seq-length`*.15/10)*10') + group.add_argument('--overlapping-eval', type=int, default=32, + help='sliding window for overlapping eval ') + group.add_argument('--cloze-eval', action='store_true', + help='Evaluation dataset from `--valid-data` is a cloze task') + group.add_argument('--eval-hf', action='store_true', + help='perform evaluation with huggingface openai model.' + 'use `--load` to specify weights path to be loaded') + group.add_argument('--load-openai', action='store_true', + help='load openai weights into our model. Use `--load` ' + 'to specify weights path to be loaded') + + return parser + +def add_text_generate_args(parser): + """Text generate arguments.""" + + group = parser.add_argument_group('Text generation', 'configurations') + group.add_argument("--temperature", type=float, default=1.0) + group.add_argument("--top_p", type=float, default=0.0) + group.add_argument("--top_k", type=int, default=0) + group.add_argument("--out-seq-length", type=int, default=256) + return parser + + +def add_data_args(parser): + """Train/valid/test data arguments.""" + + group = parser.add_argument_group('data', 'data configurations') + + group.add_argument('--model-parallel-size', type=int, default=1, + help='size of the model parallel.') + group.add_argument('--shuffle', action='store_true', + help='Shuffle data. Shuffling is deterministic ' + 'based on seed and current epoch.') + group.add_argument('--train-data', nargs='+', default=None, + help='Whitespace separated filenames or corpora names ' + 'for training.') + + group.add_argument('--use-npy-data-loader', action='store_true', + help='Use the numpy data loader. If set, then' + 'train-data-path, val-data-path, and test-data-path' + 'should also be provided.') + group.add_argument('--train-data-path', type=str, default='', + help='path to the training data') + group.add_argument('--val-data-path', type=str, default='', + help='path to the validation data') + group.add_argument('--test-data-path', type=str, default='', + help='path to the test data') + group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt', + help='the filename containing all the shards sizes') + + group.add_argument('--delim', default=',', + help='delimiter used to parse csv data files') + group.add_argument('--text-key', default='sentence', + help='key to use to extract text from json/csv') + group.add_argument('--eval-text-key', default=None, + help='key to use to extract text from ' + 'json/csv evaluation datasets') + group.add_argument('--valid-data', nargs='*', default=None, + help="""Filename for validation data.""") + group.add_argument('--split', default='1000,1,1', + help='comma-separated list of proportions for training,' + ' validation, and test split') + group.add_argument('--test-data', nargs='*', default=None, + help="""Filename for testing""") + + group.add_argument('--lazy-loader', action='store_true', + help='whether to lazy read the data set') + group.add_argument('--loose-json', action='store_true', + help='Use loose json (one json-formatted string per ' + 'newline), instead of tight json (data file is one ' + 'json string)') + group.add_argument('--presplit-sentences', action='store_true', + help='Dataset content consists of documents where ' + 'each document consists of newline separated sentences') + group.add_argument('--num-workers', type=int, default=2, + help="""Number of workers to use for dataloading""") + group.add_argument('--tokenizer-model-type', type=str, + default='bert-large-uncased', + help="Model type to use for sentencepiece tokenization \ + (one of ['bpe', 'char', 'unigram', 'word']) or \ + bert vocab to use for BertWordPieceTokenizer (one of \ + ['bert-large-uncased', 'bert-large-cased', etc.])") + group.add_argument('--tokenizer-path', type=str, default='tokenizer.model', + help='path used to save/load sentencepiece tokenization ' + 'models') + group.add_argument('--tokenizer-type', type=str, + default='BertWordPieceTokenizer', + choices=['CharacterLevelTokenizer', + 'SentencePieceTokenizer', + 'BertWordPieceTokenizer', + 'GPT2BPETokenizer'], + help='what type of tokenizer to use') + group.add_argument("--cache-dir", default=None, type=str, + help="Where to store pre-trained BERT downloads") + group.add_argument('--use-tfrecords', action='store_true', + help='load `--train-data`, `--valid-data`, ' + '`--test-data` from BERT tf records instead of ' + 'normal data pipeline') + group.add_argument('--seq-length', type=int, default=512, + help="Maximum sequence length to process") + group.add_argument('--max-preds-per-seq', type=int, default=None, + help='Maximum number of predictions to use per sequence.' + 'Defaults to math.ceil(`--seq-length`*.15/10)*10.' + 'MUST BE SPECIFIED IF `--use-tfrecords` is True.') + + return parser + +def get_args(): + """Parse all the args.""" + + parser = argparse.ArgumentParser(description='PyTorch BERT Model') + parser = add_model_config_args(parser) + parser = add_fp16_config_args(parser) + parser = add_bf16_config_args(parser) + parser = add_training_args(parser) + parser = add_evaluation_args(parser) + parser = add_text_generate_args(parser) + parser = add_data_args(parser) + + # Include DeepSpeed configuration arguments + parser = deepspeed.add_config_arguments(parser) + + args = parser.parse_args() + + if not args.train_data and not args.train_data_path: + print('WARNING: No training data specified') + + args.cuda = get_accelerator().is_available() + + args.rank = int(os.getenv('RANK', '0')) + args.world_size = int(os.getenv("WORLD_SIZE", '1')) + + if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): + # We are using (OpenMPI) mpirun for launching distributed data parallel processes + local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) + local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) + + # Possibly running with Slurm + num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) + nodeid = int(os.getenv('SLURM_NODEID', '0')) + + args.local_rank = local_rank + args.rank = nodeid*local_size + local_rank + args.world_size = num_nodes*local_size + + args.model_parallel_size = min(args.model_parallel_size, args.world_size) + if args.rank == 0: + print('using world size: {} and model-parallel size: {} '.format( + args.world_size, args.model_parallel_size)) + + args.dynamic_loss_scale = False + if args.loss_scale is None: + args.dynamic_loss_scale = True + if args.rank == 0: + print(' > using dynamic loss scaling') + + # The args fp32_* or fp16_* meant to be active when the + # args fp16 is set. So the default behaviour should all + # be false. + if not args.fp16: + args.fp32_embedding = False + args.fp32_tokentypes = False + args.fp32_layernorm = False + + return args diff --git a/examples/Megatron-LM/bf16/__init__.py b/examples/Megatron-LM/bf16/__init__.py new file mode 100644 index 0000000..8dd3b10 --- /dev/null +++ b/examples/Megatron-LM/bf16/__init__.py @@ -0,0 +1,16 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bf16 import * diff --git a/examples/Megatron-LM/bf16/bf16.py b/examples/Megatron-LM/bf16/bf16.py new file mode 100755 index 0000000..aca1c7d --- /dev/null +++ b/examples/Megatron-LM/bf16/bf16.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.accelerator.real_accelerator import get_accelerator + + +FLOAT_TYPES = (torch.FloatTensor, get_accelerator().FloatTensor) +HALF_TYPES = (torch.HalfTensor, get_accelerator().HalfTensor) + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + +def fp32_to_bf16(val): + """Convert fp32 `val` to bf16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.bf16() + return val + return conversion_helper(val, half_conversion) + +def bf16_to_fp32(val): + """Convert bf16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + return conversion_helper(val, float_conversion) + +class BF16_Module(nn.Module): + def __init__(self, module): + super(BF16_Module, self).__init__() + self.add_module('module', module.bfloat16()) + + def forward(self, *inputs, **kwargs): + return bf16_to_fp32(self.module(*(fp32_to_bf16(inputs)), **kwargs)) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + diff --git a/examples/Megatron-LM/configure_data.py b/examples/Megatron-LM/configure_data.py new file mode 100644 index 0000000..9598921 --- /dev/null +++ b/examples/Megatron-LM/configure_data.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""parses arguments and preps data loader""" + +import copy +import torch +import data_utils + +import mpu + +class DataConfig: + + def __init__(self, defaults={}): + super(DataConfig, self).__init__() + self.defaults = defaults + + def apply(self, args): + if torch.distributed.get_rank() == 0: + print('configuring data') + self.apply_defaults(args) + return make_loaders(args) + + def set_defaults(self, **kwargs): + for k, v in kwargs.items(): + self.defaults[k] = v + + def apply_defaults(self, args): + for k, v in self.defaults.items(): + k = k.replace('-', '_') + if not hasattr(args, k): + setattr(args, k, v) + + +def make_data_loader(dataset, batch_size, args): + + shuffle = args.shuffle + if shuffle: + sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) + distributed = world_size > 1 + drop_last = distributed + + if distributed: + batch_sampler = data_utils.samplers.DistributedBatchSampler(sampler, + batch_size, + drop_last, + rank, + world_size) + else: + batch_sampler = torch.utils.data.BatchSampler(sampler, + batch_size, + drop_last) + + data_loader = torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True) + + return data_loader + + +def make_tfrecord_loaders(args): + """Load train/val/test dataset from shuffled TFRecords""" + + import data_utils.tf_dl + data_set_args = {'batch_size': args.batch_size, + 'max_seq_len': args.seq_length, + 'max_preds_per_seq': args.max_preds_per_seq, + 'train': True, + 'num_workers': max(args.num_workers, 1), + 'seed': args.seed + args.rank + 1, + 'threaded_dl': args.num_workers > 0 + } + train = data_utils.tf_dl.TFRecordDataLoader(args.train_data, + **data_set_args) + data_set_args['train'] = False + if args.eval_seq_length is not None: + data_set_args['max_seq_len'] = args.eval_seq_length + if args.eval_max_preds_per_seq is not None: + data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq + valid = None + if args.valid_data is not None: + valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data, + **data_set_args) + test = None + if args.test_data is not None: + test = data_utils.tf_dl.TFRecordDataLoader(args.test_data, + **data_set_args) + tokenizer = data_utils.make_tokenizer(args.tokenizer_type, + train, + args.tokenizer_path, + args.vocab_size, + args.tokenizer_model_type, + cache_dir=args.cache_dir) + + return (train, valid, test), tokenizer + + +def make_loaders(args): + """makes training/val/test""" + + if args.use_tfrecords: + return make_tfrecord_loaders(args) + world_size = torch.distributed.get_world_size( + group=mpu.get_data_parallel_group()) + batch_size = args.batch_size * world_size + eval_batch_size = batch_size + if args.eval_batch_size is not None: + eval_batch_size = args.eval_batch_size * world_size + seq_length = args.seq_length + if seq_length < 0: + seq_length = seq_length * world_size + eval_seq_length = args.eval_seq_length + if eval_seq_length is not None and eval_seq_length < 0: + eval_seq_length = eval_seq_length * world_size + split = get_split(args) + data_set_args = { + 'path': args.train_data, + 'seq_length': seq_length, + 'lazy': args.lazy_loader, + 'delim': args.delim, + 'text_key': args.text_key, + 'label_key': 'label', + 'non_binary_cols': None, + 'ds_type': args.data_set_type, + 'split': split, + 'loose': args.loose_json, + 'tokenizer_type': args.tokenizer_type, + 'tokenizer_model_path': args.tokenizer_path, + 'vocab_size': args.vocab_size, + 'model_type': args.tokenizer_model_type, + 'cache_dir': args.cache_dir, + 'max_preds_per_seq': args.max_preds_per_seq, + 'presplit_sentences': args.presplit_sentences} + + eval_set_args = copy.copy(data_set_args) + eval_set_args['split'] = [1.] + # if optional eval args were set then replace their + # equivalent values in the arg dict + if eval_seq_length: + eval_set_args['seq_length'] = eval_seq_length + if args.eval_max_preds_per_seq: + eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq + if args.eval_text_key is not None: + eval_set_args['text_key'] = args.eval_text_key + + # make datasets splits and tokenizer + train = None + valid = None + test = None + + if args.train_data is not None: + train, tokenizer = data_utils.make_dataset(**data_set_args) + if data_utils.should_split(split): + train, valid, test = train + eval_set_args['tokenizer'] = tokenizer + + # make training and val dataset if necessary + if valid is None and args.valid_data is not None: + eval_set_args['path'] = args.valid_data + valid, tokenizer = data_utils.make_dataset(**eval_set_args) + eval_set_args['tokenizer'] = tokenizer + if test is None and args.test_data is not None: + eval_set_args['path'] = args.test_data + test, tokenizer = data_utils.make_dataset(**eval_set_args) + + # wrap datasets with data loader + if train is not None and args.batch_size > 0: + train = make_data_loader(train, batch_size, args) + args.do_train = True + else: + args.do_train = False + eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size + if valid is not None: + valid = make_data_loader(valid, eval_batch_size, args) + args.do_valid = True + else: + args.do_valid = False + if test is not None: + test = make_data_loader(test, eval_batch_size, args) + args.do_test = True + else: + args.do_test = False + + return (train, valid, test), tokenizer + +def get_split(args): + """ + Get dataset splits from comma separated string list + """ + splits = [] + if args.split.find(',') != -1: + splits = [float(s) for s in args.split.split(',')] + elif args.split.find('/') != -1: + splits = [float(s) for s in args.split.split('/')] + else: + splits = [float(args.split)] + split_total = sum(splits) + if split_total < 1.: + splits.append(1-split_total) + while len(splits) < 3: + splits.append(0.) + splits = splits[:3] + if args.valid_data is not None: + splits[1] = 0. + if args.test_data is not None: + splits[2] = 0. + final_sum = sum(splits) + return [s/final_sum for s in splits] + +def configure_data(): + + """add cmdline flags for configuring datasets""" + # These are options that are used by data_utils, but are either + # deprecated or not meant to be exposed to the command line user. + # These options are intneded to be set in code by specific scripts. + defaults = { + 'world_size': 1, + 'rank': -1, + 'persist_state': 0, + 'lazy': False, + 'transpose': False, + 'data_set_type': 'supervised', + 'seq_length': 256, + 'eval_seq_length': 256, + 'samples_per_shard': 100 + } + + return DataConfig(defaults=defaults) diff --git a/examples/Megatron-LM/data_utils/__init__.py b/examples/Megatron-LM/data_utils/__init__.py new file mode 100644 index 0000000..6cb092c --- /dev/null +++ b/examples/Megatron-LM/data_utils/__init__.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""utils for creating datasets""" +import os +import math + +from .samplers import DistributedBatchSampler +from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset +from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader +from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer +from . import corpora + +TRAIN_DATA = 0 +VAL_DATA = 1 +TEST_DATA = 2 + +def should_split(split): + """ + given split proportions checks if should split + Examples: + >>> should_split([10,0,0]) + False + >>> should_split([1,.1,.2]) + True + """ + return max(split)/sum(split) != 1. + +def get_ext(path): + """gets path extension""" + return os.path.splitext(path)[1] + +def get_dataset(path, **kwargs): + """gets dataset object based on keyword args and file at `path`""" + if supported_corpus(path): + return corpora.NAMED_CORPORA[path](**kwargs) + ext = get_ext(path) + if '.json' in ext: + text = json_dataset(path, **kwargs) + elif ext in ['.csv', '.tsv']: + text = csv_dataset(path, **kwargs) + else: + raise NotImplementedError('data file type %s is not supported'%(ext)) + return text + +def supported_corpus(corpus_name): + """checks if corpus name is defined in `corpora.py`""" + return corpus_name in corpora.NAMED_CORPORA + +def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=None, split=[1.], + delim=',', loose=False, binarize_sent=False, drop_unlabeled=False, tokenizer=None, + tokenizer_type='CharacterLevelTokenizer', tokenizer_model_path=None, vocab_size=None, + model_type='bpe', pad_token=0, character_converage=1.0, non_binary_cols=None, **kwargs): + """function to create datasets+tokenizers for common options""" + if isinstance(process_fn, str): + process_fn = eval(process_fn) + if non_binary_cols is not None: + # multilabel dataset support (only for csvs) + label_key = non_binary_cols + def get_dataset_from_path(path_): + if lazy: + # get lazily loaded dataset + named_corpora = False + if supported_corpus(path_): + named_corpora = True + name = path_ + path_ = corpora.NAMED_CORPORA[path_].PATH + if not exists_lazy(path_, data_type='data'): + # create cached version of dataset for lazy loading if it doesn't exist + text = get_dataset(name if named_corpora else path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, + delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, is_lazy=True) + if '.json' not in get_ext(path_): + make_lazy(path_, text.X, data_type='data') + text = lazy_array_loader(path_, data_type='data', map_fn=process_fn) + else: + # get dataset + text = get_dataset(path_, text_key=text_key, label_key=label_key, binarize_sent=binarize_sent, + delim=delim, drop_unlabeled=drop_unlabeled, loose_json=loose, preprocess_fn=process_fn) + return text + # get one or multiple datasets and concatenate + if isinstance(path, str): + path = [path] + datasets = [get_dataset_from_path(p) for p in path] + if len(datasets) == 1: + ds = datasets[0] + else: + ds = ConcatDataset(datasets) + # make tokenizer for dataset + if tokenizer is None: + tokenizer = make_tokenizer(tokenizer_type, ds, tokenizer_model_path, vocab_size, model_type, + pad_token, character_converage, **kwargs) + + ds_type = '' + if 'ds_type' in kwargs: + ds_type = kwargs['ds_type'] + ds.SetTokenizer(tokenizer) + # Split dataset into train/val/test (and wrap bert dataset) + if should_split(split): + ds = split_ds(ds, split) + if ds_type.lower() == 'bert': + presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False + ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds] + elif ds_type.lower() == 'gpt2': + ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds] + else: + if ds_type.lower() == 'bert': + presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False + ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences) + elif ds_type.lower() == 'gpt2': + ds = GPT2Dataset(ds, max_seq_len=seq_length) + return ds, tokenizer diff --git a/examples/Megatron-LM/data_utils/corpora.py b/examples/Megatron-LM/data_utils/corpora.py new file mode 100755 index 0000000..57de0b5 --- /dev/null +++ b/examples/Megatron-LM/data_utils/corpora.py @@ -0,0 +1,75 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""several datasets with preset arguments""" +from .datasets import json_dataset, csv_dataset +import os + +class wikipedia(json_dataset): + """ + dataset for wikipedia with arguments configured for convenience + + command line usage: `--train-data wikipedia` + """ + PATH = 'data/wikipedia/wikidump_lines.json' + assert_str = "make sure to set PATH for wikipedia data_utils/corpora.py" + def __init__(self, **kwargs): + assert os.path.exists(wikipedia.PATH), \ + wikipedia.assert_str + if not kwargs: + kwargs = {} + kwargs['text_key'] = 'text' + kwargs['loose_json'] = True + super(wikipedia, self).__init__(wikipedia.PATH, **kwargs) + + +class webtext(json_dataset): + """ + dataset for webtext with arguments configured for convenience + + command line usage: `--train-data webtext` + """ + PATH = 'data/webtext/data.json' + assert_str = "make sure to set PATH for webtext data_utils/corpora.py" + def __init__(self, **kwargs): + assert os.path.exists(webtext.PATH), \ + webtext.assert_str + if not kwargs: + kwargs = {} + kwargs['text_key'] = 'text' + kwargs['loose_json'] = True + super(webtext, self).__init__(webtext.PATH, **kwargs) + +class c4_en(json_dataset): + """ + dataset for c4/en with arguments configured for convenience + + command line usage: `--train-data c4/en` + """ + PATH = 'data/c4/en/c4-train.json' + assert_str = "make sure to set PATH for c4/en data_utils/corpora.py" + def __init__(self, **kwargs): + assert os.path.exists(c4_en.PATH), \ + c4_en.assert_str + if not kwargs: + kwargs = {} + kwargs['text_key'] = 'text' + kwargs['loose_json'] = True + super(c4_en, self).__init__(c4_en.PATH, **kwargs) + +NAMED_CORPORA = { + 'wikipedia': wikipedia, + 'webtext': webtext, + 'c4/en': c4_en, +} diff --git a/examples/Megatron-LM/data_utils/datasets.py b/examples/Megatron-LM/data_utils/datasets.py new file mode 100644 index 0000000..7596621 --- /dev/null +++ b/examples/Megatron-LM/data_utils/datasets.py @@ -0,0 +1,884 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""dataset objects for jsons, csvs, and BERT datasets""" + +import os +import time +import pickle as pkl +from operator import itemgetter +from bisect import bisect_right +import json +import csv +import math +import random +from itertools import accumulate + +from torch.utils import data +import torch +import pandas as pd +import numpy as np + +import nltk +from nltk import tokenize + +from .lazy_loader import lazy_array_loader, exists_lazy, make_lazy, get_lazy_path +from .tokenization import Tokenization + +class ConcatDataset(data.Dataset): + """ + Dataset to concatenate multiple datasets. + Purpose: useful to assemble different existing datasets, possibly + large-scale datasets as the concatenation operation is done in an + on-the-fly manner. + Arguments: + datasets (sequence): List of datasets to be concatenated. + """ + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets, **kwargs): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, 'datasets should not be an empty iterable' + self.datasets = list(datasets) + self.is_lazy = sum([isinstance(ds, lazy_array_loader) for ds in self.datasets]) == len(self.datasets) + self.cumulative_sizes = self.cumsum(self.datasets) + self._X = None + self._Y = None + self._lens = None + + def SetTokenizer(self, tokenizer): + for ds in self.datasets: + ds.SetTokenizer(tokenizer) + + def GetTokenizer(self): + return self.datasets[0].GetTokenizer() + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + def lens(self): + if self._lens is None: + self._lens = [] + if self.is_lazy: + for data in self.datasets: + self._lens.extend(data.lens) + else: + for data in self.datasets: + self._lens.extend([len(d['text']) if isinstance(d, dict) else len(d) for d in data]) + return self._lens + + @property + def X(self): + if self._X is None: + self._X = [] + for data in self.datasets: + self._X.extend(data.X) + return self._X + + @property + def Y(self): + if self._Y is None: + self._Y = [] + for data in self.datasets: + self._Y.extend(list(data.Y)) + self._Y = np.array(self._Y) + return self._Y + + @property + def cummulative_sizes(self): + warnings.warn("cummulative_sizes attribute is renamed to " + "cumulative_sizes", DeprecationWarning, stacklevel=2) + return self.cumulative_sizes + +class SplitDataset(data.Dataset): + """ + Dataset wrapper to access a subset of another dataset. + Purpose: useful to index into existing datasets, possibly + large-scale datasets as the subindexing operation is done in an + on-the-fly manner. + Arguments: + ds (Dataset or array-like): List of datasets to be subindexed + split_inds (1D array-like): List of indices part of subset + """ + def __init__(self, ds, split_inds, **kwargs): + self.split_inds = list(split_inds) + self.wrapped_data = ds + self.is_lazy = isinstance(ds, lazy_array_loader) or (hasattr(ds, 'is_lazy') and ds.is_lazy) + if self.is_lazy: + self.lens = itemgetter(*self.split_inds)(list(self.wrapped_data.lens)) + self._X = None + self._Y = None + + def __len__(self): + return len(self.split_inds) + + def __getitem__(self, index): + return self.wrapped_data[self.split_inds[index]] + + def SetTokenizer(self, tokenizer): + self.wrapped_data.SetTokenizer(tokenizer) + + def GetTokenizer(self): + return self.wrapped_data.GetTokenizer() + + @property + def X(self): + if self._X is None: + self._X = itemgetter(*self.split_inds)(self.wrapped_data.X) + return self._X + + @property + def Y(self): + if self._Y is None: + self._Y = np.array(itemgetter(*self.split_inds)(self.wrapped_data.Y)) + return self._Y + + def __iter__(self): + for idx in self.split_inds: + yield self.wrapped_data[idx] + +def split_ds(ds, split=[.8,.2,.0], shuffle=True): + """ + Split a dataset into subsets given proportions of how + much to allocate per split. If a split is 0% returns None for that split. + Purpose: Useful for creating train/val/test splits + Arguments: + ds (Dataset or array-like): Data to be split. + split (1D array-like): proportions to split `ds`. `sum(splits) != 0` + shuffle (boolean): Randomly split dataset. Default: True + """ + split_sum = sum(split) + if split_sum == 0: + raise Exception('Split cannot sum to 0.') + split = np.array(split) + split /= split_sum + ds_len = len(ds) + inds = np.arange(ds_len) + if shuffle: + np.random.shuffle(inds) + start_idx = 0 + residual_idx = 0 + rtn_ds = [None]*len(split) + for i, f in enumerate(split): + if f != 0: + proportion = ds_len*split[i] + residual_idx += proportion % 1 + split_ = int(int(proportion) + residual_idx) + split_inds = inds[start_idx:start_idx+max(split_, 1)] + rtn_ds[i] = SplitDataset(ds, split_inds) + start_idx += split_ + residual_idx %= 1 + return rtn_ds + +class csv_dataset(data.Dataset): + """ + Class for loading datasets from csv files. + Purpose: Useful for loading data for unsupervised modeling or transfer tasks + Arguments: + path (str): Path to csv file with dataset. + tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None + preprocess_fn (callable): Callable that process a string into desired format. + delim (str): delimiter for csv. Default: ',' + binarize_sent (bool): binarize label values to 0 or 1 if they\'re on a different scale. Default: False + drop_unlabeled (bool): drop rows with unlabelled values. Always fills remaining empty + columns with -1 (regardless if rows are dropped based on value) Default: False + text_key (str): key to get text from csv. Default: 'sentence' + label_key (str): key to get label from json dictionary. Default: 'label' + Attributes: + X (list): all strings from the csv file + Y (np.ndarray): labels to train with + """ + def __init__(self, path, tokenizer=None, preprocess_fn=None, delim=',', + binarize_sent=False, drop_unlabeled=False, text_key='sentence', label_key='label', + **kwargs): + self.is_lazy = False + self.preprocess_fn = preprocess_fn + self.SetTokenizer(tokenizer) + self.path = path + self.delim = delim + self.text_key = text_key + self.label_key = label_key + self.drop_unlabeled = drop_unlabeled + + if '.tsv' in self.path: + self.delim = '\t' + + + self.X = [] + self.Y = [] + try: + cols = [text_key] + if isinstance(label_key, list): + cols += label_key + else: + cols += [label_key] + data = pd.read_csv(self.path, sep=self.delim, usecols=cols, encoding='latin-1') + except: + data = pd.read_csv(self.path, sep=self.delim, usecols=[text_key], encoding='latin-1') + + data = data.dropna(axis=0) + + self.X = data[text_key].values.tolist() + try: + self.Y = data[label_key].values + except Exception as e: + self.Y = np.ones(len(self.X))*-1 + + if binarize_sent: + self.Y = binarize_labels(self.Y, hard=binarize_sent) + + def SetTokenizer(self, tokenizer): + if tokenizer is None: + self.using_tokenizer = False + if not hasattr(self, '_tokenizer'): + self._tokenizer = tokenizer + else: + self.using_tokenizer = True + self._tokenizer = tokenizer + + def GetTokenizer(self): + return self._tokenizer + + @property + def tokenizer(self): + if self.using_tokenizer: + return self._tokenizer + return None + + def __len__(self): + return len(self.X) + + def __getitem__(self, index): + """process+tokenize string and return string,label,and stringlen""" + x = self.X[index] + if self.tokenizer is not None: + x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn) + elif self.preprocess_fn is not None: + x = self.preprocess_fn(x) + y = self.Y[index] + if isinstance(y, str): + if self.tokenizer is not None: + y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn) + elif self.preprocess_fn is not None: + y = self.preprocess_fn(y) + return {'text': x, 'length': len(x), 'label': y} + + def write(self, writer_gen=None, path=None, skip_header=False): + """ + given a generator of metrics for each of the data points X_i, + write the metrics, text, and labels to a csv file + """ + if path is None: + path = self.path+'.results' + print('generating csv at ' + path) + with open(path, 'w') as csvfile: + c = csv.writer(csvfile, delimiter=self.delim) + if writer_gen is not None: + #if first item of generator is a header of what the metrics mean then write header to csv file + if not skip_header: + header = (self.label_key,)+tuple(next(writer_gen))+(self.text_key,) + c.writerow(header) + for i, row in enumerate(writer_gen): + row = (self.Y[i],)+tuple(row)+(self.X[i],) + c.writerow(row) + else: + c.writerow([self.label_key, self.text_key]) + for row in zip(self.Y, self.X): + c.writerow(row) + +class json_dataset(data.Dataset): + """ + Class for loading datasets from a json dump. + Purpose: Useful for loading data for unsupervised modeling or transfer tasks + Arguments: + path (str): path to json file with dataset. + tokenizer (data_utils.Tokenizer): Tokenizer to use when processing text. Default: None + preprocess_fn (callable): callable function that process a string into desired format. + Takes string, maxlen=None, encode=None as arguments. Default: process_str + text_key (str): key to get text from json dictionary. Default: 'sentence' + label_key (str): key to get label from json dictionary. Default: 'label' + Attributes: + all_strs (list): list of all strings from the dataset + all_labels (list): list of all labels from the dataset (if they have it) + """ + + CACHE_LEN = 1000000 + def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False, + text_key='sentence', label_key='label', loose_json=False, **kwargs): + self.is_lazy = kwargs.get('is_lazy', False) + self.preprocess_fn = preprocess_fn + self.path = path + self.SetTokenizer(tokenizer) + self.X = [] + self.Y = [] + self.text_key = text_key + self.label_key = label_key + self.loose_json = loose_json + self.str_lens = [] + + for j in self.load_json_stream(self.path): + s = j[text_key] + self.X.append(s) + self.Y.append(j[label_key]) + if self.is_lazy and len(self.X) % json_dataset.CACHE_LEN == 0: + if self.make_lazy(self.path, self.X): + return + del self.X[:] + + if self.is_lazy: + self.make_lazy(self.path, self.X, is_pkl=True) + + if binarize_sent: + self.Y = binarize_labels(self.Y, hard=binarize_sent) + + def SetTokenizer(self, tokenizer): + if tokenizer is None: + self.using_tokenizer = False + if not hasattr(self, '_tokenizer'): + self._tokenizer = tokenizer + else: + self.using_tokenizer = True + self._tokenizer = tokenizer + + def GetTokenizer(self): + return self._tokenizer + + @property + def tokenizer(self): + if self.using_tokenizer: + return self._tokenizer + return None + + def __getitem__(self, index): + """gets the index'th string from the dataset""" + x = self.X[index] + if self.tokenizer is not None: + x = self.tokenizer.EncodeAsIds(x, self.preprocess_fn) + elif self.preprocess_fn is not None: + x = self.preprocess_fn(x) + y = self.Y[index] + if isinstance(y, str): + if self.tokenizer is not None: + y = self.tokenizer.EncodeAsIds(y, self.preprocess_fn) + elif self.preprocess_fn is not None: + y = self.preprocess_fn(y) + return {'text': x, 'length': len(x), 'label': y} + + def __len__(self): + return len(self.X) + + def make_lazy(self, path, strs, data_type='data', is_pkl=False): + """ + Make lazy version of `data_type` field of the file. Byte offsets + corresponding to data indices are stored in a `.len.pkl` data file. + """ + lazypath = get_lazy_path(path) + datapath = os.path.join(lazypath, data_type) + lenpath = os.path.join(lazypath, data_type+'.len.pkl') + + if torch.distributed.is_initialized() and torch.distributed.get_rank() > 0: + while not os.path.exists(lenpath): + time.sleep(1) + return True + + if not os.path.exists(lazypath): + os.makedirs(lazypath) + + with open(datapath, 'ab') as f: + str_cnt = 0 + for s in strs: + if isinstance(s, dict): + s = s['text'] + encoded = s.encode('utf-8') + f.write(encoded) + str_cnt = len(encoded) + self.str_lens.append(str_cnt) + + if is_pkl: + pkl.dump(self.str_lens, open(lenpath, 'wb')) + + return False + + def write(self, writer_gen=None, path=None, skip_header=False): + """ + given a generator of metrics for each of the data points X_i, + write the metrics, text, and labels to a json file + """ + if path is None: + path = self.path+'.results' + + jsons = [] + + if writer_gen is not None: + #if first item of generator is a header of what the metrics mean then write header to csv file + def gen_helper(): + keys = {} + keys[0] = self.label_key + if not skip_header: + for idx, k in enumerate(tuple(next(writer_gen))): + keys[idx+1] = k + for i, row in enumerate(writer_gen): + if i == 0 and skip_header: + for idx, _ in enumerate(row): + keys[idx+1] = 'metric_%d'%(idx,) + j = {} + for idx, v in enumerate((self.Y[i],)+tuple(row)): + k = keys[idx] + j[k] = v + yield j + else: + def gen_helper(): + for y in self.Y: + j = {} + j[self.label_key] = y + yield j + + def out_stream(): + for i, j in enumerate(gen_helper()): + j[self.text_key] = self.X[i] + yield j + + self.save_json_stream(path, out_stream()) + + def save_json_stream(self, save_path, json_stream): + if self.loose_json: + with open(save_path, 'w') as f: + for i, j in enumerate(json_stream): + write_string = '' + if i != 0: + write_string = '\n' + write_string += json.dumps(j) + f.write(write_string) + else: + jsons = [j for j in json_stream] + json.dump(jsons, open(save_path, 'w'), separators=(',', ':')) + + def load_json_stream(self, load_path): + if not self.loose_json: + jsons = json.load(open(load_path, 'r')) + generator = iter(jsons) + else: + def gen_helper(): + with open(load_path, 'r') as f: + for row in f: + yield json.loads(row) + generator = gen_helper() + + for j in generator: + if self.label_key not in j: + j[self.label_key] = -1 + yield j + +class GPT2Dataset(data.Dataset): + + def __init__(self, ds, + max_seq_len=1024, + num_samples=None, + weighted=True, + sample_across_doc=True, + random_across_doc_sampling=True, + sentence_start=False, **kwargs): + self.ds = ds + self.ds_len = len(self.ds) + self.num_samples = num_samples + if num_samples is None: + self.num_samples = 1000 * self.ds_len + self.max_seq_len = max_seq_len + self.tokenizer = self.ds.GetTokenizer() + self.ds.SetTokenizer(None) + self.weighted = weighted + self.sample_across_doc = sample_across_doc + self.random_across_doc_sampling = random_across_doc_sampling + self.sentence_start = sentence_start + self.init_weighting() + + def init_weighting(self): + if self.weighted: + if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy: + lens = np.array(self.ds.lens) + else: + lens = np.array([len(d['text']) if isinstance(d, dict) + else len(d) for d in self.ds]) + self.total_len = np.sum(lens) + self.weighting = list(accumulate(lens)) + else: + self.weighting = None + + def get_weighted_samples(self, np_rng): + if self.weighting is not None: + idx = np_rng.randint(self.total_len) + return bisect_right(self.weighting, idx) + else: + return np_rng.randint(self.ds_len) + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + # init rng + rng = random.Random(idx) + rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)]) + + # get possibly weighted random index from dataset + data_idx = self.get_weighted_samples(rng) +# data_idx = rng.choice(self.ds_len, p=self.weighting) + tokens = self.getidx(data_idx) + + # truncate or pad tokens + num_tokens = len(tokens) + tokens_to_strip = num_tokens - self.max_seq_len - 1 + if tokens_to_strip > 0: + strip_left_tokens = rng.randint(tokens_to_strip + 1) + tokens = tokens[strip_left_tokens:] + if self.sentence_start: + token_copy = list(tokens) + not_done = True + while (len(token_copy) > 0) and not_done: + tok = token_copy.pop(0) + if self.contains_sentence_end(tok): + tokens = token_copy + not_done = False + strip_right_rokens = len(tokens) - self.max_seq_len - 1 + if strip_right_rokens > 0: + tokens = tokens[:-strip_right_rokens] + + if self.sample_across_doc: + while (len(tokens) < (self.max_seq_len + 1)): + if self.random_across_doc_sampling: + data_idx = self.get_weighted_samples(rng) + else: + data_idx = (data_idx + 1) % self.ds_len + tokens += self.getidx(data_idx) + tokens = tokens[:(self.max_seq_len+1)] + + tokens = self.pad_seq(tokens) + return {'text': np.array(tokens),} + + def getidx(self, data_idx): + data = self.ds[data_idx] + if isinstance(data, dict): + data = data['text'] + # tokenize + tokenization = self.tokenizer.EncodeAsIds(data) + tokenization.append(self.tokenizer.get_command('eos')) + tokens = tokenization.tokenization + return tokens + + def pad_seq(self, seq): + total_tokens = self.max_seq_len + 1 + num_pad_tokens = max(0, total_tokens - len(seq)) + seq += [self.tokenizer.get_command('pad').Id]*(num_pad_tokens) + return seq + + def contains_sentence_end(self, tok): + tok = self.tokenizer.IdToToken(tok) + if '.' in tok: + return True + if '?' in tok: + return True + if '!' in tok: + return True + return False + +class bert_sentencepair_dataset(data.Dataset): + """ + Dataset containing sentencepairs for BERT training. Each index corresponds to a randomly generated sentence pair. + Arguments: + ds (Dataset or array-like): data corpus to use for training + max_seq_len (int): maximum sequence length to use for a sentence pair + mask_lm_prob (float): proportion of tokens to mask for masked LM + max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10 + short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len + dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1) + + """ + def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, weighted=True,**kwargs): + self.ds = ds + self.ds_len = len(self.ds) + self.tokenizer = self.ds.GetTokenizer() + self.vocab_words = list(self.tokenizer.text_token_vocab.values()) + self.ds.SetTokenizer(None) + self.max_seq_len = max_seq_len + self.mask_lm_prob = mask_lm_prob + if max_preds_per_seq is None: + max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10 + self.max_preds_per_seq = max_preds_per_seq + self.short_seq_prob = short_seq_prob + self.dataset_size = dataset_size + if self.dataset_size is None: + self.dataset_size = self.ds_len * (self.ds_len-1) + self.presplit_sentences = presplit_sentences + if not self.presplit_sentences: + nltk.download('punkt', download_dir="./nltk") + self.weighted = weighted + self.get_weighting() + + def get_weighting(self): + if self.weighted: + if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy: + lens = np.array(self.ds.lens) + else: + lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds]) + self.total_len = np.sum(lens) + self.weighting = list(accumulate(lens)) + else: + self.weighting = None + + def get_weighted_samples(self, np_rng): + if self.weighting is not None: + idx = np_rng.randint(self.total_len) + return bisect_right(self.weighting, idx) + else: + return np_rng.randint(self.ds_len) + + def __len__(self): + return self.dataset_size + + def __getitem__(self, idx): + # get rng state corresponding to index (allows deterministic random pair) + rng = random.Random(idx) + np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)]) + # get seq length + target_seq_length = self.max_seq_len + short_seq = False + if rng.random() < self.short_seq_prob: + target_seq_length = rng.randint(2, target_seq_length) + short_seq = True + + # get sentence pair and label + is_random_next = None + lena = 0 + lenb = 0 + while (is_random_next is None) or (lena < 1) or (lenb < 1): + tokensa, tokensb, is_random_next = self.create_random_sentencepair(target_seq_length, rng, np_rng) + lena = len(tokensa[0]) + lenb = len(tokensb[0]) + + # truncate sentence pair to max_seq_len + tokensa, tokensb = self.truncate_seq_pair(tokensa, tokensb, self.max_seq_len, rng) + # join sentence pair, mask, and pad + tokens, mask, mask_labels, pad_mask = self.create_masked_lm_predictions(tokensa, tokensb, self.mask_lm_prob, self.max_preds_per_seq, self.vocab_words, rng) + sample = {'text': np.array(tokens[0]), 'types': np.array(tokens[1]), 'is_random': int(is_random_next), 'mask': np.array(mask), 'mask_labels': np.array(mask_labels), 'pad_mask': np.array(pad_mask)} + return sample + + def sentence_split(self, document): + """split document into sentences""" + lines = document.split('\n') + if self.presplit_sentences: + return [line for line in lines if line] + rtn = [] + for line in lines: + if line != '': + rtn.extend(tokenize.sent_tokenize(line)) + return rtn + + def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False): + """tokenize sentence and get token types""" + tokens = self.tokenizer.EncodeAsIds(sent).tokenization + str_type = 'str' + str(sentence_num) + token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens) + return tokens, token_types + + def get_doc(self, idx): + """gets text of document corresponding to idx""" + rtn = self.ds[idx] + if isinstance(rtn, dict): + rtn = rtn['text'] + return rtn + + def create_random_sentencepair(self, target_seq_length, rng, np_rng): + """ + fetches a random sentencepair corresponding to rng state similar to + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L248-L294 + """ + is_random_next = None + + curr_strs = [] + curr_str_types = [] + curr_len = 0 + + while curr_len < 1: + curr_len = 0 + doc_a = None + while doc_a is None: + if self.weighted: + # doc_a_idx = np_rng.choice(self.ds_len, p=self.weighting) + doc_a_idx = self.get_weighted_samples(np_rng) + else: + doc_a_idx = rng.randint(0, self.ds_len-1) + doc_a = self.sentence_split(self.get_doc(doc_a_idx)) + if not doc_a: + doc_a = None + + random_start_a = rng.randint(0, len(doc_a)-1) + while random_start_a < len(doc_a): + sentence = doc_a[random_start_a] + sentence, sentence_types = self.sentence_tokenize(sentence, 0, random_start_a == 0, random_start_a == len(doc_a)) + curr_strs.append(sentence) + curr_str_types.append(sentence_types) + curr_len += len(sentence) + if random_start_a == len(doc_a) - 1 or curr_len >= target_seq_length: + break + random_start_a = (random_start_a+1) + + if curr_strs: + num_a = 1 + if len(curr_strs) >= 2: + num_a = rng.randint(0, len(curr_strs)) + + tokens_a = [] + token_types_a = [] + for j in range(num_a): + tokens_a.extend(curr_strs[j]) + token_types_a.extend(curr_str_types[j]) + + tokens_b = [] + token_types_b = [] + is_random_next = False + if len(curr_strs) == 1 or rng.random() < 0.5: + is_random_next = True + target_b_length = target_seq_length - len(tokens_a) + b_len = 0 + while b_len < 1: + doc_b = None + while doc_b is None: + doc_b_idx = rng.randint(0, self.ds_len - 2) + doc_b_idx += int(doc_b_idx >= doc_a_idx) + + doc_b = self.sentence_split(self.get_doc(doc_b_idx)) + if not doc_b: + doc_b = None + + random_start_b = rng.randint(0, len(doc_b)-1) + while random_start_b < len(doc_b): + sentence_b = doc_b[random_start_b] + new_b_tokens, new_b_types = self.sentence_tokenize(sentence_b, 1, random_start_b == 0, random_start_b == len(doc_b)) + b_len += len(new_b_tokens) + tokens_b.extend(new_b_tokens) + token_types_b.extend(new_b_types) + if len(tokens_b) >= target_b_length: + break + random_start_b = (random_start_b+1) + else: + is_random_next = False + for j in range(num_a, len(curr_strs)): + tokens_b.extend(curr_strs[j]) + token_types_b.extend(curr_str_types[j]) + + return (tokens_a, token_types_a), (tokens_b, token_types_b), is_random_next + + def truncate_seq_pair(self, a, b, max_seq_len, rng): + """ + Truncate sequence pair according to original BERT implementation: + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L391 + """ + tokens_a, token_types_a = a + tokens_b, token_types_b = b + max_num_tokens = max_seq_len - 3 + while True: + len_a = len(tokens_a) + len_b = len(tokens_b) + total_length = len_a + len_b + if total_length <= max_num_tokens: + break + if len(tokens_a) > len(tokens_b): + trunc_tokens = tokens_a + trunc_types = token_types_a + else: + trunc_tokens = tokens_b + trunc_types = token_types_b + + assert len(trunc_tokens) >= 1 + + if rng.random() < 0.5: + trunc_tokens.pop(0) + trunc_types.pop(0) + else: + trunc_tokens.pop() + trunc_types.pop() + return (tokens_a, token_types_a), (tokens_b, token_types_b) + + def mask_token(self, idx, tokens, types, vocab_words, rng): + """ + helper function to mask `idx` token from `tokens` according to + section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf + """ + label = tokens[idx] + if rng.random() < 0.8: + new_label = self.tokenizer.get_command('MASK').Id + else: + if rng.random() < 0.5: + new_label = label + else: + new_label = rng.choice(vocab_words) + + tokens[idx] = new_label + + return label + + def pad_seq(self, seq): + """helper function to pad sequence pair""" + num_pad = max(0, self.max_seq_len - len(seq)) + pad_mask = [0] * len(seq) + [1] * num_pad + seq += [self.tokenizer.get_command('pad').Id] * num_pad + return seq, pad_mask + + def create_masked_lm_predictions(self, a, b, mask_lm_prob, max_preds_per_seq, vocab_words, rng): + """ + Mask sequence pair for BERT training according to: + https://github.com/google-research/bert/blob/master/create_pretraining_data.py#L338 + """ + tokens_a, token_types_a = a + tokens_b, token_types_b = b + tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id] + token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]] + + len_a = len(tokens_a) + len_b = len(tokens_b) + + cand_indices = [idx+1 for idx in range(len_a)] + [idx+2+len_a for idx in range(len_b)] + + rng.shuffle(cand_indices) + + output_tokens, pad_mask = self.pad_seq(list(tokens)) + output_types, _ = self.pad_seq(list(token_types)) + + num_to_predict = min(max_preds_per_seq, max(1, int(round(len(tokens) * mask_lm_prob)))) + + mask = [0] * len(output_tokens) + mask_labels = [-1] * len(output_tokens) + + for idx in sorted(cand_indices[:num_to_predict]): + mask[idx] = 1 + label = self.mask_token(idx, output_tokens, output_types, vocab_words, rng) + mask_labels[idx] = label + + return (output_tokens, output_types), mask, mask_labels, pad_mask diff --git a/examples/Megatron-LM/data_utils/file_utils.py b/examples/Megatron-LM/data_utils/file_utils.py new file mode 100755 index 0000000..44ecc41 --- /dev/null +++ b/examples/Megatron-LM/data_utils/file_utils.py @@ -0,0 +1,253 @@ +# This file is provided as is from: +# https://github.com/huggingface/pytorch-pretrained-BERT +# Please refer to their repository for copyright. + +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import json +import logging +import os +import shutil +import tempfile +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding="utf-8") as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/examples/Megatron-LM/data_utils/lazy_loader.py b/examples/Megatron-LM/data_utils/lazy_loader.py new file mode 100644 index 0000000..899a654 --- /dev/null +++ b/examples/Megatron-LM/data_utils/lazy_loader.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""utils for loading text from disk""" +import os +import mmap +import pickle as pkl +import time +from itertools import accumulate + +import torch +from torch.multiprocessing import Lock + +def get_lazy_path(path): + """ + Gets directory path where lazy files are stored. + """ + return os.path.splitext(path)[0]+'.lazy' + +def exists_lazy(path, data_type='data'): + """ + Check if we've already made a lazy version of this file for the `data_type` field. + """ + if not os.path.exists(get_lazy_path(path)): + return False + contents = os.listdir(get_lazy_path(path)) + if data_type not in contents: + return False + if data_type+'.len.pkl' not in contents: + return False + return True + +def make_lazy(path, strs, data_type='data'): + """ + Make lazy version of `data_type` field of the file. Byte offsets + corresponding to data indices are stored in a `.len.pkl` data file. + """ + lazypath = get_lazy_path(path) + if not os.path.exists(lazypath): + os.makedirs(lazypath) + datapath = os.path.join(lazypath, data_type) + lenpath = os.path.join(lazypath, data_type+'.len.pkl') + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + with open(datapath, 'wb') as f: + str_lens = [] + str_cnt = 0 + for s in strs: + if isinstance(s, dict): + s = s['text'] + encoded = s.encode('utf-8') + f.write(encoded) + str_cnt = len(encoded) + str_lens.append(str_cnt) + pkl.dump(str_lens, open(lenpath, 'wb')) + else: + while not os.path.exists(lenpath): + time.sleep(1) + +def split_strings(strings, start, chr_lens): + """ + Split strings based on string lengths and given start. + """ + return [strings[i-start:j-start] for i, j in zip([start]+chr_lens[:-1], chr_lens)] + +class ProcessorTokenizer: + """ + callable class that runs a preprocessing, as well as tokenization step, + on input text. + """ + def __init__(self, tokenizer, process_fn=None): + self.tokenizer = tokenizer + self.process_fn = process_fn + + def __call__(self, string): + if self.tokenizer is not None: + string = self.tokenizer(string, process_fn=self.process_fn) + elif self.process_fn is not None: + string = self.process_fn(string) + return string + +class lazy_array_loader(object): + """ + Arguments: + path: path to directory where array entries are concatenated into one big string file + and the .len file are located + data_type (str): Some datsets have multiple fields that are stored in different paths. + `data_type` specifies which of these fields to load in this class + mem_map (boolean): Specifies whether to memory map file `path` + map_fn (callable): Fetched strings are passed through map_fn before being returned. + + Example of lazy loader directory structure: + file.json + file.lazy/ + data_type1 + data_type1.len.pkl + data_type2 + data_type2.len.pkl + """ + def __init__(self, path, data_type='data', mem_map=False, map_fn=None): + lazypath = get_lazy_path(path) + datapath = os.path.join(lazypath, data_type) + #get file where array entries are concatenated into one big string + self._file = open(datapath, 'rb') + self.file = self._file + #memory map file if necessary + self.mem_map = mem_map + self.is_lazy = True + if self.mem_map: + self.file = mmap.mmap(self.file.fileno(), 0, prot=mmap.PROT_READ) + lenpath = os.path.join(lazypath, data_type+'.len.pkl') + self.lens = pkl.load(open(lenpath, 'rb')) + self.ends = list(accumulate(self.lens)) + self.dumb_ends = list(self.ends) + self.read_lock = Lock() + self.process_fn = map_fn + self.map_fn = map_fn + self._tokenizer = None + + def SetTokenizer(self, tokenizer): + """ + logic to set and remove (set to None) tokenizer. + combines preprocessing/tokenization into one callable. + """ + if tokenizer is None: + if not hasattr(self, '_tokenizer'): + self._tokenizer = tokenizer + else: + self._tokenizer = tokenizer + self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn) + + def GetTokenizer(self): + return self._tokenizer + + def __getitem__(self, index): + """ + read file and splice strings based on string ending array `self.ends` + """ + if not isinstance(index, slice): + if index == 0: + start = 0 + else: + start = self.ends[index-1] + end = self.ends[index] + rtn = self.file_read(start, end) + if self.map_fn is not None: + return self.map_fn(rtn) + else: + # if slice, fetch strings with 1 diskread and then splice in memory + chr_lens = self.ends[index] + if index.start == 0 or index.start is None: + start = 0 + else: + start = self.ends[index.start-1] + stop = chr_lens[-1] + strings = self.file_read(start, stop) + rtn = split_strings(strings, start, chr_lens) + if self.map_fn is not None: + return self.map_fn([s for s in rtn]) + return rtn + + def __len__(self): + return len(self.ends) + + def file_read(self, start=0, end=None): + """read specified portion of file""" + + # atomic reads to avoid race conditions with multiprocess dataloader + self.read_lock.acquire() + # seek to start of file read + self.file.seek(start) + # read to end of file if no end point provided + if end is None: + rtn = self.file.read() + #else read amount needed to reach end point + else: + rtn = self.file.read(end-start) + self.read_lock.release() + #TODO: @raulp figure out mem map byte string bug + #if mem map'd need to decode byte string to string + rtn = rtn.decode('utf-8', 'ignore') + # rtn = str(rtn) + if self.mem_map: + rtn = rtn.decode('unicode_escape') + return rtn + diff --git a/examples/Megatron-LM/data_utils/samplers.py b/examples/Megatron-LM/data_utils/samplers.py new file mode 100644 index 0000000..c42a381 --- /dev/null +++ b/examples/Megatron-LM/data_utils/samplers.py @@ -0,0 +1,139 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""batch samplers that work with either random or sequential data samplers""" +import math +import os +import sys + +import torch +from torch.utils import data +import numpy as np + +class RandomSampler(data.sampler.Sampler): + r""" + Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, + but this class lets the user set an epoch like DistributedSampler + Samples elements randomly. If without replacement, then sample from a shuffled dataset. + If with replacement, then user can specify ``num_samples`` to draw. + Arguments: + data_source (Dataset): dataset to sample from + num_samples (int): number of samples to draw, default=len(dataset) + replacement (bool): samples are drawn with replacement if ``True``, default=False + """ + + def __init__(self, data_source, replacement=False, num_samples=None): + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.epoch = -1 + + if self._num_samples is not None and replacement is False: + raise ValueError("With replacement=False, num_samples should not be specified, " + "since a random permute will be performed.") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError("num_samples should be a positive integer " + "value, but got num_samples={}".format(self.num_samples)) + if not isinstance(self.replacement, bool): + raise ValueError("replacement should be a boolean value, but got " + "replacement={}".format(self.replacement)) + + @property + def num_samples(self): + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self): + n = len(self.data_source) + g = torch.Generator() + if self.epoch >= 0: + g.manual_seed(self.epoch) + if self.replacement: + return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist()) + return iter(torch.randperm(n, generator=g).tolist()) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + +class DistributedBatchSampler(data.sampler.BatchSampler): + """ + similar to normal implementation of distributed sampler, except implementation is at the + batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary + data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. + """ + def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False): + super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) + if rank == -1: + assert False, 'should not be here' + rank = torch.distributed.get_rank() + self.rank = rank + self.world_size = world_size + self.sampler.wrap_around = 0 + self.wrap_around = 0 + self.wrap_last = wrap_last + self.start_iter = 0 + + def __iter__(self): + batch = [] + last_batch = None + i = 0 + for idx in self.data_iterator(self.sampler, wrap_around=False): + batch.append(idx) + if len(batch) == self.batch_size: + tbatch = self._batch(batch) + if i >= self.start_iter: + yield tbatch + self.start_iter = 0 + i += 1 + last_batch = np.array(list(tbatch)) + batch = [] + batch_len = len(batch) + if batch_len > 0 and not self.drop_last: + if self.wrap_last: + self.sampler.wrap_around -= (self.batch_size) + self.wrap_around += (len(batch)) + self.wrap_around %= self.batch_size + if isinstance(self.sampler, TransposedSampler): + for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)): + if i == 0: + continue + batch.append(idx) + new_batch_len = len(batch) + if len(batch) == self.batch_size: + break + yield self._batch(batch) + if self.wrap_last: + self.sampler.wrap_around += self.batch_size + + def data_iterator(self, _iter, wrap_around=False): + """iterates through data and handles wrap around""" + for i, idx in enumerate(_iter): + if i < self.wrap_around%self.batch_size: + continue + if wrap_around: + self.wrap_around += 1 + self.wrap_around %= self.batch_size + yield idx + + def _batch(self, batch): + """extracts samples only pertaining to this worker's batch""" + start = self.rank*self.batch_size//self.world_size + end = (self.rank+1)*self.batch_size//self.world_size + return batch[start:end] diff --git a/examples/Megatron-LM/data_utils/tf_dl.py b/examples/Megatron-LM/data_utils/tf_dl.py new file mode 100755 index 0000000..29b4056 --- /dev/null +++ b/examples/Megatron-LM/data_utils/tf_dl.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DataLoader for TFRecords""" + +import queue +import threading + +import tensorflow as tf +tf.enable_eager_execution() +import torch +import numpy as np + +class TFRecordDataLoader(object): + def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False): + assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords" + tf.set_random_seed(seed) + if isinstance(records, str): + records = [records] + + self.record_converter = Record2Example({"input_ids": tf.FixedLenFeature([max_seq_len], tf.int64), + "input_mask": tf.FixedLenFeature([max_seq_len], tf.int64), + "segment_ids": tf.FixedLenFeature([max_seq_len], tf.int64), + "masked_lm_positions": tf.FixedLenFeature([max_preds_per_seq], tf.int64), + "masked_lm_ids": tf.FixedLenFeature([max_preds_per_seq], tf.int64), + "masked_lm_weights": tf.FixedLenFeature([max_preds_per_seq], tf.float32), + "next_sentence_labels": tf.FixedLenFeature([1], tf.int64)}) + + #Instantiate dataset according to original BERT implementation + if train: + self.dataset = tf.data.Dataset.from_tensor_slices(tf.constant(records)) + self.dataset = self.dataset.repeat() + self.dataset = self.dataset.shuffle(buffer_size=len(records)) + + # use sloppy tfrecord dataset + self.dataset = self.dataset.apply( + tf.contrib.data.parallel_interleave( + tf.data.TFRecordDataset, + sloppy=train, + cycle_length=min(num_workers, len(records)))) + self.dataset = self.dataset.shuffle(buffer_size=100) + else: + self.dataset = tf.data.TFRecordDataset(records) + self.dataset = self.dataset.repeat() + + # Instantiate dataloader (do not drop remainder for eval) + loader_args = {'batch_size': batch_size, + 'num_parallel_batches': num_workers, + 'drop_remainder': train} + self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args)) + self.threaded_dl = threaded_dl + self.num_workers = num_workers + + def __iter__(self): + if self.threaded_dl: + data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers)) + for item in data_iter: + yield item + else: + data_iter = iter(self.dataloader) + for item in data_iter: + yield convert_tf_example_to_torch_tensors(item) + +class Record2Example(object): + def __init__(self, feature_map): + self.feature_map = feature_map + + def __call__(self, record): + """Decodes a BERT TF record to a TF example.""" + example = tf.parse_single_example(record, self.feature_map) + for k, v in list(example.items()): + if v.dtype == tf.int64: + example[k] = tf.to_int32(v) + return example + +def convert_tf_example_to_torch_tensors(example): + item = {k: (v.numpy()) for k,v in example.items()} + mask = np.zeros_like(item['input_ids']) + mask_labels = np.ones_like(item['input_ids'])*-1 + for b, row in enumerate(item['masked_lm_positions'].astype(int)): + for i, idx in enumerate(row): + if item['masked_lm_weights'][b, i] != 0: + mask[b, idx] = 1 + mask_labels[b, idx] = item['masked_lm_ids'][b, i] + output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'], + 'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels} + return {k: torch.from_numpy(v) for k,v in output.items()} + +class MultiprocessLoader(object): + def __init__(self, dataloader, num_workers=2): + self.dl = dataloader + self.queue_size = 2*num_workers + + def __iter__(self): + output_queue = queue.Queue(self.queue_size) + output_thread = threading.Thread(target=_multiproc_iter, + args=(self.dl, output_queue)) + output_thread.daemon = True + output_thread.start() + + while output_thread.is_alive(): + yield output_queue.get(block=True) + else: + print(RuntimeError('TF record data loader thread exited unexpectedly')) + +def _multiproc_iter(dl, output_queue): + data_iter = iter(dl) + for item in data_iter: + tensors = convert_tf_example_to_torch_tensors(item) + output_queue.put(tensors, block=True) \ No newline at end of file diff --git a/examples/Megatron-LM/data_utils/tokenization.py b/examples/Megatron-LM/data_utils/tokenization.py new file mode 100755 index 0000000..36a487a --- /dev/null +++ b/examples/Megatron-LM/data_utils/tokenization.py @@ -0,0 +1,890 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)""" +from collections import namedtuple +import random +import os +import csv +import torch + +import nltk +from nltk import tokenize as nltk_tokenize +import sentencepiece as spm + +from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP + +from .tokenization_gpt2 import GPT2Tokenizer +import regex as re + +def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe', pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs): + """ + Helper function to instantiate a tokenizer given common combinations of options. + """ + tokenizer_class = tokenizer_type + if isinstance(tokenizer_class, str): + tokenizer_class = eval(tokenizer_class) + if tokenizer_class is BertWordPieceTokenizer: + return BertWordPieceTokenizer(model_type, **kwargs) + elif tokenizer_class is GPT2BPETokenizer: + return GPT2BPETokenizer(**kwargs) + text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type, + pad_token=pad_token, character_coverage=character_coverage) + return Tokenizer(text_tokenizer, command_tokens, type_tokens) + +class Tokenization(object): + """ + Tokenization object to hold tokenization, (processed text),and original + text. Can hold tokenization as Ids or tokens. + + It also holds command tokens (pad, unk, etc.) for the tokenization. + This allows functions to pad/operate on tokenizations without having + access to the full tokenizer, just the tokenization. + + Several standard array operations are implemented (insert, append, extend). + """ + def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True): + self.tokenization = tokenization + self.text = text + if self.text is None: + self.text = self.tokenization + self.original_text = original_text + if self.original_text is None: + self.original_text = self.text + self.command_tokens = command_tokens + self.asIds = asIds + self.parse_command_tokens() + + def set_command_tokens(self, command_tokens): + self.command_tokens = command_tokens + return self.parse_command_tokens() + + def parse_command_tokens(self): + if self.command_tokens is None: + return + for command_token in self.command_tokens: + if self.asIds: + setattr(self, command_token.name, command_token.Id) + else: + setattr(self, command_token.name, command_token.token) + + def __getitem__(self, index): + return self.tokenization[index] + + def __len__(self): + return len(self.tokenization) + + def insert(self, idx, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.insert(idx, other.Id) + if idx == 0: + self.text = other.token + self.text + self.original_text = other.token + self.original_text + elif idx == len(self.tokenization)-1: + self.text += other.token + self.original_text += other.token + elif isinstance(other, Tokenization): + self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:] + else: + self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:] + + def append(self, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.append(other.Id) + self.text += other.token + self.original_text += other.token + elif isinstance(other, Tokenization): + self.tokenization.extend(other.tokenization) + self.text += other.text + self.original_text += other.original_text + else: + self.tokenization.append(other) + return self + + def extend(self, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.append(other.Id) + self.text += other.token + self.original_text += other.token + elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)): + self.tokenization.extend([o.Id for o in other]) + self.text += [o.token for o in other] + self.original_text += [o.token for o in other] + elif isinstance(other, Tokenization): + self.tokenization.extend(other.tokenization) + self.text += other.text + self.original_text += other.original_text + else: + self.tokenization.extend(other) + return self + +"""define some default command tokens for the tokenizer to use""" +token_format = "<{0}>" + +COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id')) + +def prep_command_tokens(tokenlist, token_format=token_format): + return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist] + +class CommandToken(object): + def __init__(self, name, token, Id): + self.name = name + self.token = token + self.Id = Id + + def __str__(self): + return str(COMMAND_TUPLE(self.name, self.token, self.Id)) + +DEFAULT_COMMAND_TOKENS = [ + ('pad', 0), + ('eos', 1), + ('bos', 2), + ('unk', 3), + ('sep', 4), + ('L2R', 5), + ('ENC', 6), + ('MASK', 7), +] +DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS) + +"""define some default type tokens for bert training""" + +TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id')) + +def prep_type_tokens(tokenlist, token_format=token_format): + return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist] + +class TypeToken(object): + def __init__(self, name, token, Id): + self.name = name + self.token = token + self.Id = Id + + def __str__(self): + return str(TYPE_TUPLE(self.name, self.token, self.Id)) + +DEFAULT_TYPE_TOKENS = [ + ('function', 0), + ('command', 1), + ('str0', 2), + ('str1', 3), + ('str2', 4), + ('embedding0', 5), + ('embedding1', 6), + ('embedding2', 7), + ('arg0', 8), + ('arg1', 9), + ('arg2', 10), +] +DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS) + +class Tokenizer(object): + """ + Tokenizer object that handles text tokenization, command tokens, and type tokens. + + Command tokens and text tokens are stored together in one mapping of size + `len(text_tokenizer)+len(command_tokens)`. Command tokens are stored as first + `len(command_tokens)` tokens. Token idx is stored at `idx+len(command_tokens)`. + + Token types are stored in a separate mapping of size `len(type_tokens)`. + """ + def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None): + # set text tokenizer + self.text_tokenizer = text_tokenizer + if not hasattr(self, 'num_text_tokens'): + self.num_text_tokens = len(self.text_tokenizer) + + # set command tokens + if command_tokens is None: + command_tokens = DEFAULT_COMMAND_TOKENS + self._command_tokens = command_tokens + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + if not hasattr(self, 'num_command_tokens'): + self.num_command_tokens = len(self._command_tokens) + if not hasattr(self, 'num_tokens'): + self.num_tokens = self.num_command_tokens + self.num_text_tokens + + # set type tokens + if type_tokens is None: + type_tokens = DEFAULT_TYPE_TOKENS + self.type_tokens = type_tokens + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + if not hasattr(self, 'num_type_tokens'): + self.num_type_tokens = len(self.type_tokens) + + # parse tokens and vocabs from tokenizer + self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens) + self._vocab = {t:Id for Id,t in self.command_id_map.items()} + self._vocab.update({t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()}) + + self._text_tokens = list(self.text_tokenizer.tokens) + self._text_token_vocab = {t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()} + + + def __call__(self, text, process_fn=None): + """run preprocessing and encode text as Ids""" + return self.EncodeAsIds(text, process_fn=process_fn) + + def __len__(self): + """total number of tokens""" + return self.num_tokens + + def get_command(self, name): + """get command token corresponding to `name`""" + return self.command_name_map[name] + + def get_type(self, name): + """get type token corresponding to `name`""" + return self.type_name_map[name] + + @property + def tokens(self): + """list (or iterable) of all tokens for tokenizer""" + return self._tokens + + @property + def vocab(self): + """dictionary mapping tokens to ids for tokenizer""" + return self._vocab + + @property + def token_types(self): + """list (or iterable) of all token types for tokenizer""" + return self._token_types + + @property + def token_type_vocab(self): + """dictionary mapping token types to ids for tokenizer""" + return self._token_type_vocab + + @property + def command_tokens(self): + """list (or iterable) of all command tokens for tokenizer""" + return self._command_token_tokens + + @property + def command_token_vocab(self): + """dictionary mapping command tokens to ids for tokenizer""" + return self._command_token_vocab + + @property + def text_tokens(self): + """list (or iterable) of text tokens for text tokenizer""" + return self._text_tokens + + @property + def text_token_vocab(self): + """dictionary mapping text tokens to ids for text tokenizer""" + return self._text_token_vocab + + def EncodeAsIds(self, text, process_fn=None): + """ + encode text using text tokenizer and shift Id values for command tokens + """ + tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn) + tokenization.tokenization = [t+self.num_command_tokens for t in tokenization.tokenization] + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def EncodeAsTokens(self, text, process_fn=None): + """ + encode text as tokens using text tokenizer + """ + tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def IdToToken(self, Id, type_token=False): + """convert Id to token accounting for command and type tokens""" + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id < self.num_command_tokens: + return self.command_id_map[Id].token + return self.text_tokenizer.IdToToken(Id-self.num_command_tokens) + + def TokenToId(self, token, type_token=False): + """convert token to Id accounting for command and type tokens""" + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + if token in self.command_token_map: + return self.command_token_map[token].Id + return self.text_tokenizer.TokenToId(token)+self.num_command_tokens + + def DecodeIds(self, Ids, type_token=False): + """ + convert Ids to tokens accounting for command and type tokens, tokens + are joined and returned as a string. + """ + if type_token: + return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) + rtn_strs = [] + current_str = [] + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + for Id in Ids: + if isinstance(Id, CommandToken): + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + current_str = [] + rtn_strs.append(t.token) + elif Id < self.num_command_tokens: + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + current_str = [] + rtn_strs.append(self.command_id_map[Id].token) + else: + current_str.append(Id - self.num_command_tokens) + if current_str != []: + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + return ' '.join(rtn_strs) + + def DecodeTokens(self, Tokens, type_token=False): + """ + convert tokens to a string accounting for command and type tokens. + """ + if type_token: + return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens) + rtn_strs = [] + current_str = [] + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + for t in Tokens: + if isinstance(t, CommandToken): + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + current_str = [] + rtn_strs.append(t.token) + elif t in self.command_token_map: + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + current_str = [] + rtn_strs.append(t) + else: + current_str.append(t) + if current_str != []: + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + return ' '.join(rtn_strs) + +class TextTokenizer(object): + """ + Interface for text tokenizer + """ + def __init__(self): + if not hasattr(self, 'num_text_tokens'): + self.num_text_tokens = 0 + if not hasattr(self, 'num_tokens'): + self.num_tokens = self.num_text_tokens + + def __call__(self, text, process_fn=None): + return self.EncodeAsIds(text, process_fn) + + def __len__(self): + return self.num_text_tokens + + @property + def tokens(self): + """list (or iterable) of text tokens for text tokenizer""" + raise NotImplementedError('TextTokenizer tokens property not implemented') + + @property + def vocab(self): + """dictionary mapping tokens to ids""" + raise NotImplementedError('TextTokenizer vocab property not implemented') + + @staticmethod + def exists(model_path): + """check if the filepath for a text tokenizer exists""" + raise NotImplementedError('TextTokenizer exists method not implemented') + + def Train(self, corpus): + """train a tokenizer on a data corpus and save model for future use""" + raise NotImplementedError('TextTokenizer Train not implemented') + + def EncodeAsIds(self, text, process_fn=None): + """ + Preprocess text and encode as ids. Return a tokenization object with + original text, processed text, and id tokenization. + """ + raise NotImplementedError('TextTokenizer EncodeAsIds not implemented') + + def EncodeAsTokens(self, text, process_fn=None): + """ + Preprocess text and encode as tokens. Return a tokenization object with + original text, processed text, and token tokenization. + """ + raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented') + + def IdToToken(self, Id): + """Convert an Id to Token. Reverse lookup of self.vocab""" + raise NotImplementedError('TextTokenizer IdToToken not implemented') + + def TokenToId(self, token): + """Convert a Token to Id. Lookup of self.vocab""" + raise NotImplementedError('TextTokenizer TokenToId not implemented') + + def DecodeIds(self, Ids): + """Convert a list or tokenization object of Ids to a text string""" + raise NotImplementedError('TextTokenizer DecodeIds not implemented') + + def DecodeTokens(self, Tokens): + """Convert a list or tokenization object of tokens to a text string""" + raise NotImplementedError('TextTokenizer DecodeTokens not implemented') + + +class CharacterLevelTokenizer(TextTokenizer): + """ + Text tokenizer for ASCII-256 Character Level Tokenization. + """ + def __init__(self, **kwargs): + self.num_text_tokens = 256 + super(CharacterLevelTokenizer, self).__init__() + self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)] + self._vocab = {t: i for i,t in enumerate(self._tokens)} + + def __len__(self): + return 256 + + @staticmethod + def exists(model_path): + return True + + def Train(self, corpus): + pass + + @property + def tokens(self): + return self._tokens + + @property + def vocab(self): + return self._vocab + + def EncodeAsIds(self, text, process_fn=None): + """convert text to ascii 256 Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + processed_text = str(processed_text) + tokens = [self.TokenToId(c) for c in processed_text] + return Tokenization(tokens, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert text to ascii 256 characters""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + processed_text = str(processed_text) + tokens = [c for c in processed_text] + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id): + """ascii index to character""" + return chr(Id) + + def TokenToId(self, token): + """ascii character to index""" + return ord(token) + + def DecodeIds(self, Ids): + """converts ascii ids to tokens before joining them into text""" + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return ''.join([self.IdToToken(tok) for tok in Ids]) + + def DecodeTokens(self, Tokens): + """just concatenates ascii tokens into text""" + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return ''.join(Tokens) + + +MAX_SENTENCEPIECE_SENTENCES = 100000000 + +def get_corpus_freq(dataset, filepath, filetype='tsv'): + """ + Take corpus, split it into sentences, and extract word frequencies. + Write frequencies to `filepath` as a tsv. Only write the first + MAX_SENTENCEPIECE_SENTENCES most common words to the file. + """ + nltk.download('punkt', download_dir="./nltk") + if filetype == 'tsv': + delimiter = '\t' + else: + delimiter = ',' + + print("compute corpus frequency\n", flush=True) + + total_sentence_count = 0 + maxlen = 0 + freqs = {} + for entry in dataset: + if isinstance(entry, dict): + entry = entry['text'] + lines = entry.strip().split('\n') + for line in lines: + sentences = nltk_tokenize.sent_tokenize(line) + total_sentence_count += len(sentences) + for sentence in sentences: + maxlen = max(len(line), maxlen) + for word in sentence.split(): + if word not in freqs: + freqs[word] = 0 + freqs[word] += 1 + + print("length of freqs before truncating " + str(len(freqs)), flush=True) + print("file path for freq " + str(filepath), flush=True) + + freqs_sorted = {} + counter=0 + for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True): + if counter >= MAX_SENTENCEPIECE_SENTENCES: + break + counter+=1 + freqs_sorted[word] = count + + + print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True) + + with open(filepath, 'w') as f: + writer = csv.writer(f, delimiter=delimiter) + for k, v in freqs_sorted.items(): + writer.writerow([str(k), str(v)]) + + return total_sentence_count, maxlen + +class SentencePieceTokenizer(TextTokenizer): + """Trains and uses sentencepiece for text tokenization""" + def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0, **kwargs): + self.character_coverage = character_coverage + self.model_type = model_type.lower() + self.spm_model = model_path + self.num_text_tokens = vocab_size + make_train = not SentencePieceTokenizer.exists(self.spm_model) + if make_train: + assert corpus is not None and self.num_text_tokens is not None + self.Train(corpus, self.num_text_tokens) + self._tokens = [] + self._vocab = {} + self.load_spm_model() + super(SentencePieceTokenizer, self).__init__() + + def __len__(self): + return self.num_text_tokens + + @property + def tokens(self): + return self._tokens + + @property + def vocab(self): + return self._vocab + + @staticmethod + def exists(model_path): + if model_path is None: + return False + # check if path exists + dne = not os.path.exists(model_path) + # check if path.model exists + if dne and not model_path.endswith('.model'): + dne = not os.path.exists(model_path+'.model') + return not dne + + def load_spm_model(self): + """load sentencepiece model and parse vocab""" + if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'): + self.spm_model = self.spm_model+'.model' + self.sp = spm.SentencePieceProcessor() + self.sp.Load(self.spm_model) + self.vocab_size = self.num_text_tokens = len(self.sp) + self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)] + self._vocab = {t: i for i,t in enumerate(self._tokens)} + + def Train(self, corpus, num_text_tokens): + """train sentencepiece model on corpus using word frequencies""" + self.num_text_tokens = num_text_tokens + use_model_path = self.spm_model + random_hash = str(random.randint(0, 2147483647)) + if use_model_path is None: + use_model_path = random_hash + if use_model_path.endswith('.model'): + use_model_path = use_model_path[:use_model_path.rfind('.model')] + input_path = use_model_path+'.tsv.'+random_hash + line_count, maxlenline = get_corpus_freq(corpus, input_path) + line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES) + print('line count used as input_sentence_size ', line_count, flush=True) + print('training sentencepiece model', flush=True) + train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \ + + ' --model_type={model_type} --character_coverage={character_coverage} ' \ + + '--input_sentence_size={input_sentence_size} ' \ + + '--input_format=tsv' + train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens, + model_type=self.model_type, character_coverage=self.character_coverage, + input_sentence_size=int(line_count)) #, #)#, + print("calling spm.SentencePieceTrainer.Train(%s)"%(train_string), flush=True) + spm.SentencePieceTrainer.Train(train_string) + os.remove(input_path) + self.spm_model = use_model_path+'.model' + print('sentencepiece model written to '+self.spm_model, flush=True) + + def EncodeAsIds(self, text, process_fn=None): + """convert text to sentencepiece Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.sp.EncodeAsIds(processed_text) + return Tokenization(tokens, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert text to sentencepiece tokens""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.sp.EncodeAsTokens(processed_text) + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id): + """convert Id to sentencpiece token""" + return self.sp.IdToPiece(Id) + + def TokenToId(self, token): + """convert sentencpiece token to Id""" + return self.sp.PieceToId(token) + + def DecodeIds(self, Ids): + """converts ids to a text string""" + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return self.sp.DecodeIds(Ids) + + def DecodeTokens(self, Tokens): + """converts sentencepiece tokens to a text string""" + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.sp.DecodeTokens(Tokens) + +class BertWordPieceTokenizer(Tokenizer): + """ + Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization + in BERT training. Default to bert-large-uncased tokenizer. + """ + def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs): + # default to bert-large-uncased tokenizer + if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP: + tokenizer_model_type = 'bert-large-uncased' + if torch.distributed.get_rank() == 0: + print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir) + do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type) + self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir) + if torch.distributed.get_rank() == 0: + print('loaded', tokenizer_model_type) + # disable max len warnings by increasing max len + self.text_tokenizer.max_len = int(1e12) + + # set command tokens from wordpiece tokenizer values + self.num_command_tokens = 5 + self.num_tokens = len(self.text_tokenizer.vocab) + self.num_text_tokens = self.num_tokens-5 + self.num_type_tokens = 2 + + self._command_tokens = [ + CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']), + CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']), + CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']), + CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']), + CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']), + ] + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + # set type tokens + self.type_tokens = [ + TypeToken('str0', '', 0), + TypeToken('str1', '', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + # parse tokens and vocabs from tokenizer + + self._tokens = list(self.text_tokenizer.vocab.keys()) + self._vocab = {k:v for k,v in self.text_tokenizer.vocab.items()} + + self._text_tokens = list(self._tokens) + self._text_token_vocab = {k:v for k,v in self.text_tokenizer.vocab.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()} + + def EncodeAsIds(self, text, process_fn=None): + """convert text to wordpiece Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.text_tokenizer.tokenize(processed_text) + Ids = self.text_tokenizer.convert_tokens_to_ids(tokens) + return Tokenization(Ids, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert wordpiece token to Id""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.text_tokenizer.tokenize(processed_text) + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id, type_token=False): + """convert Id to sentencpiece token""" + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + return self.text_tokenizer.ids_to_tokens[Id] + + def TokenToId(self, token, type_token=False): + """convert sentencpiece token to Id""" + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.vocab[token] + + def DecodeIds(self, Ids, type_token=False): + """converts ids to wordpiece tokens and joins them as a text string""" + if type_token: + return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + Tokens = [] + for Id in Ids: + Tokens.append(self.text_tokenizer.ids_to_tokens[Id] if Id != -1 else '-1') + Tokens = self.text_tokenizer.convert_ids_to_tokens(Ids) + return ' '.join(Tokens) + + def DecodeTokens(self, Tokens, type_token=False): + """converts wordpiece tokens to a text string""" + if type_token: + return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return ' '.join(Tokens) + + +class GPT2BPETokenizer(Tokenizer): + def __init__(self, cache_dir=None, **kwargs): + self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2', + cache_dir=cache_dir) + + #disable max len warnings by increasing max len + self.text_tokenizer.max_len = int(1e12) + self.num_command_tokens = 2 + self.num_tokens = len(self.text_tokenizer.encoder) + self.num_text_tokens = self.num_tokens-1 + self.num_type_tokens = 2 + + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']), + CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']), + ] + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + self.type_tokens = [ + TypeToken('str0', '', 0), + TypeToken('str1', '', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + self._tokens = list(self.text_tokenizer.encoder.keys()) + self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()} + + self._text_tokens = list(self._tokens) + self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()} + + def EncodeAsIds(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + Ids = self.text_tokenizer.encode(processed_text) + #return Tokenization(Ids, processed_text, text) + tokenization = Tokenization(Ids, processed_text, text) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + + def EncodeAsTokens(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = [] + for token in re.findall(self.text_tokenizer.pat, processed_text): + token = ''.join(self.text_tokenizer.bye_encoder[b] for b in token.encode('utf-8')) + tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' ')) + tokenization=Tokenization(tokens, processed_text, text, asIds=False) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + #return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id, type_token=False): + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + return self.text_tokenizer.decoder[Id] + + def TokenToId(self, token, type_token=False): + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.encoder[token] + + def DecodeIds(self, Ids, type_token=False): + if type_token: + return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return self.text_tokenizer.decode(Ids) + + def DecodeTokens(self, Tokens, type_token=False): + if type_token: + return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens]) + diff --git a/examples/Megatron-LM/data_utils/tokenization_gpt2.py b/examples/Megatron-LM/data_utils/tokenization_gpt2.py new file mode 100644 index 0000000..8ffd7a6 --- /dev/null +++ b/examples/Megatron-LM/data_utils/tokenization_gpt2.py @@ -0,0 +1,304 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import sys +import json +import logging +import os +import regex as re +from io import open + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", +} +PRETRAINED_MERGES_ARCHIVE_MAP = { + 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'gpt2': 1024, +} +VOCAB_NAME = 'vocab.json' +MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] + special_tokens_file = None + else: + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find files {} and {} " + "at this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + pretrained_model_name_or_path, + vocab_file, merges_file)) + return None + if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + logger.info("loading merges file {}".format(merges_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + logger.info("loading merges file {} from cache at {}".format( + merges_file, resolved_merges_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) + return tokenizer + + def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v:k for k,v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} + logger.info("Special tokens {}".format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + if sys.version_info[0] == 2: + token = ''.join(self.byte_encoder[ord(b)] for b in token) + else: + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. """ + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this OpenAI GPT model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + index = len(self.encoder) + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + + return vocab_file, merge_file, special_tokens_file diff --git a/examples/Megatron-LM/data_utils/wordpiece.py b/examples/Megatron-LM/data_utils/wordpiece.py new file mode 100755 index 0000000..81121e4 --- /dev/null +++ b/examples/Megatron-LM/data_utils/wordpiece.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections +import logging +import os +import unicodedata +from io import open + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + if self.do_basic_tokenize: + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/examples/Megatron-LM/detokenizer.py b/examples/Megatron-LM/detokenizer.py new file mode 100755 index 0000000..999251e --- /dev/null +++ b/examples/Megatron-LM/detokenizer.py @@ -0,0 +1,60 @@ +import re + +def ptb_detokenizer(string): + string = string.replace(" '", "'") + string = string.replace(" \n", "\n") + string = string.replace("\n ", "\n") + string = string.replace(" n't", "n't") + string = string.replace(" N ","1 ") + string = string.replace("$ 1", "$1") + string = string.replace("# 1", "#1") + return string + + +def wikitext_detokenizer(string): + #contractions + string = string.replace("s '", "s'") + string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) + # number separators + string = string.replace(" @-@ ", "-") + string = string.replace(" @,@ ", ",") + string = string.replace(" @.@ ", ".") + #punctuation + string = string.replace(" : ", ": ") + string = string.replace(" ; ", "; ") + string = string.replace(" . ", ". ") + string = string.replace(" ! ", "! ") + string = string.replace(" ? ", "? ") + string = string.replace(" , ", ", ") + # double brackets + string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) + string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) + string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) + string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) + string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) + # miscellaneous + string = string.replace("= = = =", "====") + string = string.replace("= = =", "===") + string = string.replace("= =", "==") + string = string.replace(" "+chr(176)+" ", chr(176)) + string = string.replace(" \n", "\n") + string = string.replace("\n ", "\n") + string = string.replace(" N ", " 1 ") + string = string.replace(" 's", "'s") + + return string + +def lambada_detokenizer(string): + return string + +def get_detokenizer(path): + for key in DETOKENIZERS.keys(): + if key in path: + print(key) + return DETOKENIZERS[key] + +DETOKENIZERS = { + 'ptb': ptb_detokenizer, + 'wikitext': wikitext_detokenizer, + 'lambada': lambada_detokenizer, +} diff --git a/examples/Megatron-LM/docker/Dockerfile b/examples/Megatron-LM/docker/Dockerfile new file mode 100644 index 0000000..e0aad23 --- /dev/null +++ b/examples/Megatron-LM/docker/Dockerfile @@ -0,0 +1,38 @@ +# =========== +# base images +# =========== +FROM nvcr.io/nvidia/pytorch:19.05-py3 + + +# =============== +# system packages +# =============== +RUN apt-get update && apt-get install -y \ + bash-completion \ + emacs \ + git \ + graphviz \ + htop \ + libopenexr-dev \ + rsync \ + wget \ +&& rm -rf /var/lib/apt/lists/* + + +# ============ +# pip packages +# ============ +RUN pip install --upgrade pip && \ + pip install --upgrade setuptools +COPY requirements.txt /tmp/ +RUN pip install --upgrade --ignore-installed -r /tmp/requirements.txt + + +# =========== +# latest apex +# =========== +RUN pip uninstall -y apex && \ +git clone https://github.com/NVIDIA/apex.git ~/apex && \ +cd ~/apex && \ +pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . + diff --git a/examples/Megatron-LM/docker/README.md b/examples/Megatron-LM/docker/README.md new file mode 100644 index 0000000..0896143 --- /dev/null +++ b/examples/Megatron-LM/docker/README.md @@ -0,0 +1 @@ +README for Megatron docker diff --git a/examples/Megatron-LM/docker/requirements.txt b/examples/Megatron-LM/docker/requirements.txt new file mode 100644 index 0000000..17b780c --- /dev/null +++ b/examples/Megatron-LM/docker/requirements.txt @@ -0,0 +1,10 @@ +boto3 +google-cloud-language +inflect +nltk +numpy +pandas +requests +sentencepiece +tensorflow +tqdm diff --git a/examples/Megatron-LM/evaluate_gpt2.py b/examples/Megatron-LM/evaluate_gpt2.py new file mode 100755 index 0000000..1fa01a7 --- /dev/null +++ b/examples/Megatron-LM/evaluate_gpt2.py @@ -0,0 +1,554 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain BERT""" + +import os +import json +import math +import random +import numpy as np +import torch +from deepspeed.accelerator.real_accelerator import get_accelerator + + +from arguments import get_args +from configure_data import configure_data +from fp16 import FP16_Module +from fp16 import FP16_Optimizer +from learning_rates import AnnealingLR +from model import GPT2Model +from model import gpt2_get_params_for_weight_decay_optimization +from model import DistributedDataParallel as DDP +import mpu +from apex.optimizers import FusedAdam as Adam +from utils import Timers +from utils import load_checkpoint +from utils import report_memory +from utils import print_params_min_max_norm +from utils import print_rank_0 + +from data_utils import make_tokenizer + +from detokenizer import * + +def get_model(args): + """Build the model.""" + + print_rank_0('building GPT2 model ...') + model = GPT2Model(num_layers=args.num_layers, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + embedding_dropout_prob=args.hidden_dropout, + attention_dropout_prob=args.attention_dropout, + output_dropout_prob=args.hidden_dropout, + max_sequence_length=args.max_position_embeddings, + checkpoint_activations=args.checkpoint_activations, + checkpoint_num_layers=args.checkpoint_num_layers, + parallel_output=not args.cloze_eval) + + print_rank_0(' > number of parameters: {}'.format( + sum([p.nelement() for p in model.parameters()]))) + + # GPU allocation. + model.to(torch.device(get_accelerator().current_device_name())) + + # Fp16 conversion. + if args.fp16: + model = FP16_Module(model) + + # Wrap model for distributed training. + model = DDP(model) + + return model + + +def setup_model(args): + """Setup model and optimizer.""" + + model = get_model(args) + + if args.load is not None: + _ = load_checkpoint( + model, None, None, args) + + return model + +def get_masks_and_position_ids(data, + eod_token, + reset_position_ids, + reset_attention_mask): + + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones( + (att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, + device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i+1):, :(i+1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i+1):] -= (i + 1 - prev_index) + prev_index = i + 1 + + return attention_mask, loss_mask, position_ids + +def get_batch(data_iterator, args, timers): + ''' get_batch subdivides the source data into chunks of + length args.seq_length. If source is equal to the example + output of the data loading example, with a seq_length limit + of 2, we'd get the following two Variables for i = 0: + ┌ a g m s ┐ ┌ b h n t ┐ + └ b h n t ┘ └ c i o u ┘ + Note that despite the name of the function, the subdivison of data is not + done along the batch dimension (i.e. dimension 1), since that was handled + by the data loader. The chunks are along dimension 0, corresponding + to the seq_len dimension in the LSTM. A Variable representing an appropriate + shard reset mask of the same dimensions is also returned. + ''' + # Items and their type. + keys = ['text', 'pad_mask'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + lm_labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + padding_mask = data_b['pad_mask'].byte() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_masks_and_position_ids( + tokens, + args.eod_token, + args.reset_position_ids, + args.reset_attention_mask) + + # Convert + if args.fp16: + attention_mask = attention_mask.half() + + return tokens, lm_labels, attention_mask, position_ids, padding_mask + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + batch = get_batch(data_iterator, args, timers) + if batch is None: + return None + tokens, lm_labels, attention_mask, position_ids, loss_mask = batch + timers('batch generator').stop() + # Forward model. + if args.eval_hf: + output, _ = model(tokens) + else: + output = model(tokens, position_ids, attention_mask) + + if not args.cloze_eval: + #losses = torch.nn.CrossEntropyLoss(reduce=False)( + losses = mpu.vocab_parallel_cross_entropy( + output.contiguous().float(), lm_labels.contiguous()) + loss_mask = loss_mask.contiguous() + loss_mask = loss_mask.view(-1) + lm_loss = torch.sum( + losses.view(-1) * loss_mask.float()) + else: + outputs = torch.argmax(output, -1).contiguous().view(-1) + acc = (outputs == lm_labels.contiguous().view(-1)).float() + loss_mask = loss_mask.contiguous().view(-1).float() + lm_loss = torch.sum(acc * loss_mask) + + return lm_loss + + +def evaluate(data_loader, model, args, timers, + num_iterations=None): + """Evaluation.""" + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_lm_loss = 0 + if num_iterations is not None: + max_iters = num_iterations + else: + if mpu.get_model_parallel_rank() == 0: + max_iters_gpu = get_accelerator().LongTensor([len(data_loader)]) + else: + max_iters_gpu = get_accelerator().LongTensor([0]) + torch.distributed.broadcast(max_iters_gpu, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + max_iters = max_iters_gpu[0].item() + print_rank_0('global rank: {} | max iters: {}'.format( + torch.distributed.get_rank(), max_iters)) + + if data_loader is not None: + data_iterator = iter(data_loader) + else: + data_iterator = None + + with torch.no_grad(): + iteration = 0 + while iteration < max_iters: + if iteration % args.log_interval == 0: + print_rank_0('global rank: {} | iteration: {}'.format( + torch.distributed.get_rank(), iteration)) + # Forward evaluation. + lm_loss = forward_step(data_iterator, model, args, timers) + if lm_loss is None: + break + # Reduce across processes. + if isinstance(model, DDP): + torch.distributed.all_reduce(lm_loss.data) + if args.cloze_eval: + lm_loss.data = lm_loss.data / args.world_size + else: + lm_loss.data = lm_loss.data / args.model_parallel_size + + if not args.cloze_eval: + total_lm_loss += lm_loss.data.detach().float().item()/(args.num_tokenized_tokens-1) + else: + total_lm_loss += lm_loss.data.detach().float().item() + + iteration += 1 + + # Move model back to the train mode. + model.train() + + return total_lm_loss + + +def evaluate_and_print_results(prefix, data_iterator, model, + args, timers, num_iterations=None): + """Helper function to evaluate and dump results on screen.""" + if not args.cloze_eval: + lm_loss = evaluate(data_iterator, model, args, timers, num_iterations) + val_loss = lm_loss + ppl = math.exp(min(20, val_loss)) + token_ratio = (args.num_tokenized_tokens-1)/(args.num_original_tokens-1) + adjusted_ppl = math.exp(min(20, val_loss*token_ratio)) + print_rank_0('-' * 100) + string = ' validation results on {} | '.format(prefix) + string += 'avg loss: {:.4E} | '.format(val_loss) + string += 'ppl: {:.4E} | '.format(ppl) + string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) + string += 'token ratio: {} |'.format(token_ratio) + length = len(string) + 1 + print_rank_0('-' * length) + print_rank_0(string) + print_rank_0('-' * length) + + return val_loss + else: + num_correct = evaluate(data_iterator, model, args, timers, num_iterations) + acc = num_correct / args.num_examples + print_rank_0('-' * 100) + string = ' validation results on {} | '.format(prefix) + string += 'number correct: {:.4E} | '.format(num_correct) + string += 'total examples: {:.4E} | '.format(args.num_examples) + string += 'avg accuracy: {:.4E}'.format(acc) + length = len(string) + 1 + print_rank_0('-' * length) + print_rank_0(string) + print_rank_0('-' * length) + return acc + + +def initialize_distributed(args): + """Initialize torch.distributed.""" + + # Manually set the device ids. + device = args.rank % get_accelerator().device_count() + if args.local_rank is not None: + device = args.local_rank + get_accelerator().set_device(device) + # Call the init process + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, rank=args.rank, + init_method=init_method) + + # Set the model-parallel / data-parallel communicators. + mpu.initialize_model_parallel(args.model_parallel_size) + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + + if seed is not None and seed > 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +class LM_Eval_Dataset(torch.utils.data.Dataset): + def __init__(self, tokens, seq_len, pad_idx, overalapping_eval=None): + self.tokens = tokens + self.seq_len = seq_len + self.pad_idx = pad_idx + self.overalapping_eval = overalapping_eval + if self.overalapping_eval is None: + self.overalapping_eval = self.seq_len + self.overalapping_eval = max(1, self.overalapping_eval) + + self.total_targets = len(self.tokens) - 1 + # remove first sequence tokens + targets = max(self.total_targets - self.overalapping_eval, 0) + self.total_sequences = max(math.ceil(targets / self.overalapping_eval)+1, 1) + + def __len__(self): + return self.total_sequences + + def __getitem__(self, idx): + start_idx = idx * self.overalapping_eval + end_idx = start_idx + self.seq_len + tokens = self.tokens[start_idx:end_idx+1] + num_tokens = len(tokens) + pad_mask = [1]*num_tokens + if num_tokens < self.seq_len+1: + num_pad = (self.seq_len+1-num_tokens) + pad_mask += [0]*(num_pad) + tokens += [self.pad_idx] * num_pad + pad_mask = np.array(pad_mask[1:]) + if self.overalapping_eval != self.seq_len and idx!=0: + pad_mask[:-self.overalapping_eval] *= 0 + + return {'text': np.array(tokens), 'pad_mask': pad_mask} + +class Lambada_Eval_Dataset(torch.utils.data.Dataset): + def __init__(self, path, tokenizer, seq_len): + self.seq_len = seq_len + self.pad_idx = tokenizer.get_command('pad').Id + + self.tokens = [] + with open(path, 'r') as f: + for line in f.readlines(): + text = json.loads(line)['text'] + self.tokens.append(tokenizer.EncodeAsIds(text).tokenization) + + def __len__(self): + return len(self.tokens) + + def __getitem__(self, idx): + + tokens = self.tokens[idx] + num_tokens = len(tokens) + pad_mask = [0]*num_tokens + pad_mask[-1] = 1 + if num_tokens < self.seq_len+1: + num_pad = (self.seq_len+1-num_tokens) + pad_mask += [0]*(num_pad) + tokens += [self.pad_idx] * num_pad + pad_mask = np.array(pad_mask[1:]) + + return {'text': np.array(tokens), 'pad_mask': pad_mask} + +def get_tokenizer(args): + tokenizer_args = { + 'tokenizer_type': args.tokenizer_type, + 'corpus': None, + 'model_path': args.tokenizer_path, + 'vocab_size': args.vocab_size, + 'model_type': args.tokenizer_model_type, + 'cache_dir': args.cache_dir} + return make_tokenizer(**tokenizer_args) + +def get_eval_data(args): + val_dataloader = None + if mpu.get_model_parallel_rank() == 0: + eval_batch_size = args.eval_batch_size + eval_batch_size = args.batch_size if eval_batch_size is None else eval_batch_size + seq_len = args.seq_length + valid_data = args.valid_data + valid_data = valid_data[0] if isinstance(valid_data, list) else valid_data + + tokenizer = get_tokenizer(args) + + if not args.cloze_eval: + + with open(valid_data, "rb") as reader: + entire_data = reader.read().decode('utf-8') + num_original_tokens = len(entire_data.strip().split(" ")) + entire_data = get_detokenizer(valid_data)(entire_data) + tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization + num_tokenized_tokens = len(tokenized_data) + string = 'Original Tokens: %d, Detokenized tokens: %d' % (num_tokenized_tokens, num_original_tokens) + print_rank_0(string) + + eod_token = tokenizer.get_command('pad').Id + val_dataset = LM_Eval_Dataset(tokenized_data, seq_len, eod_token, + args.overlapping_eval) + else: + val_dataset = Lambada_Eval_Dataset(valid_data, tokenizer, seq_len) + num_tokenized_tokens = 0 + num_original_tokens = 0 + val_dataloader = torch.utils.data.DataLoader( + val_dataset, batch_size=eval_batch_size, drop_last=False) + + before = tokenizer.num_tokens + after = before + while after % mpu.get_model_parallel_world_size() != 0: + after += 1 + print_rank_0('> padded vocab (size: {}) with {} dummy tokens (new size: {})'. + format(before, after - before, after)) + eod_token = tokenizer.get_command('pad').Id + num_examples = len(val_dataset) + token_counts = get_accelerator().LongTensor([after, eod_token, num_examples, + num_original_tokens, + num_tokenized_tokens]) + else: + token_counts = get_accelerator().LongTensor([0, 0, 0, 0, 0]) + torch.distributed.broadcast(token_counts, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + args.vocab_size = token_counts[0].item() + args.eod_token = token_counts[1].item() + args.num_examples = token_counts[2].item() + args.num_original_tokens = token_counts[3].item() + args.num_tokenized_tokens = token_counts[4].item() + + print('global rank: {} | vocab size: {} | eod token: {} | ' + 'num_examples: {} | num_original_tokens: {} | ' + 'num_tokenized_tokens: {}'.format( + torch.distributed.get_rank(), args.vocab_size, + args.eod_token, args.num_examples, args.num_original_tokens, + args.num_tokenized_tokens )) + return val_dataloader + +def main(): + """Main training program.""" + + print('Evaluate GPT2 model') + + # Disable CuDNN. + torch.backends.cudnn.enabled = False + + # Timer. + timers = Timers() + + # Arguments. + args = get_args() + + # Pytorch distributed. + initialize_distributed(args) + + # Random seeds for reproducability. + set_random_seed(args.seed) + + # Data stuff. + eval_data = get_eval_data(args) + + # Model, optimizer, and learning rate. + if args.eval_hf: + from pytorch_pretrained_bert import GPT2LMHeadModel + from pytorch_pretrained_bert import GPT2Model as HFGPT2Model + if args.num_layers == 24: + model_path = args.load + #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M' + hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True).cuda() + model = GPT2LMHeadModel(hfmodel.config) + model.transformer.load_state_dict(hfmodel.state_dict()) + model.cuda() + else: + model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights').cuda() + else: + if args.load_openai: + from utils import move_weights + model_path = args.load + args.load = None + model = setup_model(args) + from pytorch_pretrained_bert import GPT2LMHeadModel + from pytorch_pretrained_bert import GPT2Model as HFGPT2Model + + model_path = 'gpt2' + from_tf = False + print('loading openai weights') + model.cpu() + if args.num_layers == 24: + #model_path = '/home/universal-lm-data.cosmos549/repos/gpt2_mp/models/345M' + hfmodel = HFGPT2Model.from_pretrained(model_path, cache_dir='gpt2_weights', from_tf=True) + gpt2model = GPT2LMHeadModel(hfmodel.config) + gpt2model.transformer.load_state_dict(hfmodel.state_dict()) + gpt2model + else: + gpt2model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='gpt2_weights') + model2fill = model + while isinstance(model2fill, (DDP, FP16_Module)): + model2fill = model2fill.module + move_weights(model2fill, gpt2model) + model.cuda() + else: + model = setup_model(args) + + # Run on test data. + prefix = "wiki" #os.path.basename(args.valid_data) + evaluate_and_print_results(prefix, eval_data, + model, args, timers) + + +if __name__ == "__main__": + main() diff --git a/examples/Megatron-LM/fp16/__init__.py b/examples/Megatron-LM/fp16/__init__.py new file mode 100644 index 0000000..a2c68a1 --- /dev/null +++ b/examples/Megatron-LM/fp16/__init__.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fp16util import ( + BN_convert_float, + network_to_half, + prep_param_lists, + model_grads_to_master_grads, + master_params_to_model_params, + tofp16, + to_python_float, + clip_grad_norm, + convert_module, + convert_network, + FP16Model, +) + +from .fp16 import * +from .loss_scaler import * diff --git a/examples/Megatron-LM/fp16/fp16.py b/examples/Megatron-LM/fp16/fp16.py new file mode 100755 index 0000000..e96d08f --- /dev/null +++ b/examples/Megatron-LM/fp16/fp16.py @@ -0,0 +1,634 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stable version of apex FP16 Optimizer""" +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from deepspeed.accelerator.real_accelerator import get_accelerator + +from .loss_scaler import DynamicLossScaler, LossScaler +from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm + +FLOAT_TYPES = (torch.FloatTensor, get_accelerator().FloatTensor) +HALF_TYPES = (torch.HalfTensor, get_accelerator().HalfTensor) + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + return conversion_helper(val, half_conversion) + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + return conversion_helper(val, float_conversion) + +class FP16_Module(nn.Module): + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + +# TODO: Update overflow check + downscale to use Carl's fused kernel. +class FP16_Optimizer(object): + """ + :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, + and manage static or dynamic loss scaling and master weights in a manner transparent to the user. + For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, + and changing the call to ``backward``. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + # Name the FP16_Optimizer instance to replace the existing optimizer + # (recommended but not required): + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + # loss.backward() becomes: + optimizer.backward(loss) + ... + + Example with dynamic loss scaling:: + + ... + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + # optional arg to control dynamic loss scaling behavior + # dynamic_loss_args={'scale_window' : 500}) + # Usually, dynamic_loss_args is not necessary. + + Args: + init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. + dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. + verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. + + ``init_optimizer`` is expected to have been constructed in the ordinary way. + It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be + named to replace ``init_optimizer``, for two reasons: + First, it means that references to the same name + later in the file will not have to change. + Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to + modify ``init_optimizer``. If you do choose a unique name for the new + :class:`FP16_Optimizer` instance, you should only work with this new instance, + because the preexisting optimizer might no longer behave as expected. + + ``init_optimizer`` may be any Pytorch optimizer. + It may contain a mixture of fp16 and fp32 parameters organized into any number of + ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will + ingest these ``param_groups`` and remember them. + + Calls to :: + + loss.backward() + + must be replaced with :: + + optimizer.backward(loss) + + because :class:`FP16_Optimizer` requires ownership of the backward pass to implement + loss scaling and copies to master gradients. + + .. note:: + Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients + are downscaled before being applied. This means that adjusting the loss scale, or using + dynamic loss scaling, should not require retuning the learning rate or any other + hyperparameters. + + + **Advanced options** + + **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. + See docstring for :attr:`step`. + + **Gradient clipping**: Use :attr:`clip_master_grads`. + + **Multiple losses**: If your model accumulates gradients from multiple losses, + this can be made more efficient by supplying ``update_master_grads=False`` + to :attr:`backward`. See docstring for :attr:`backward`. + + **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: + + print(optimizer.loss_scale) + optimizer.loss_scale = new_loss_scale + + For static loss scaling, manually adjusting the loss scale over time is a reasonable + thing to do. During later epochs, gradients may become smaller, and a + higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss + scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting + the loss scale is not recommended. + + **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in + Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` + should still work as intended. + """ + + def __init__(self, + init_optimizer, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=False): + if not get_accelerator().is_available(): + raise SystemError("Cannot use fp16 without CUDA.") + + self.verbose = verbose + + self.optimizer = init_optimizer + # init_state_dict sets up an alternative way to cast per-param state tensors. + # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. + # init_state_dict = init_optimizer.state_dict() + + self.fp16_groups = [] + self.fp32_from_fp16_groups = [] + self.fp32_from_fp32_groups = [] + for i, param_group in enumerate(self.optimizer.param_groups): + self.maybe_print("FP16_Optimizer processing param group {}:".format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + fp32_from_fp16_params_this_group = [] + for i, param in enumerate(param_group['params']): + if param.requires_grad: + half_tensor_type_string = 'torch.{}.HalfTensor'.format(get_accelerator().device_name()) + float_tensor_type_string = 'torch.{}.FloatTensor'.format(get_accelerator().device_name()) + if param.type() == half_tensor_type_string: + self.maybe_print("FP16_Optimizer received {} with {}" + .format(half_tensor_type_string, param.size())) + fp16_params_this_group.append(param) + master_param = param.detach().clone().float() + master_param.requires_grad = True + # Copythe model parallel flag. + master_param.model_parallel = param.model_parallel + param_group['params'][i] = master_param + fp32_from_fp16_params_this_group.append(master_param) + # Reset existing state dict key to the new master param. + # We still need to recast per-param state tensors, if any, to FP32. + if param in self.optimizer.state: + self.optimizer.state[master_param] = self.optimizer.state.pop(param) + elif param.type() == float_tensor_type_string: + self.maybe_print("FP16_Optimizer received {} with {}" + .format(float_tensor_type_string, param.size())) + fp32_params_this_group.append(param) + param_group['params'][i] = param + else: + raise TypeError("Wrapped parameters must be either " + "{} or {}. " + "Received {}".format(float_tensor_type_string, + half_tensor_type_string, + param.type())) + + self.fp16_groups.append(fp16_params_this_group) + self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors + self.optimizer.load_state_dict(self.optimizer.state_dict()) + # alternative way to cast per-param state tensors: + # self.optimizer.load_state_dict(init_state_dict) + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + if dynamic_loss_args is not None: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + else: + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + self.clip_grad_norm = clip_grad_norm + + def maybe_print(self, msg): + if self.verbose: + print(msg) + + def __getstate__(self): + raise RuntimeError("FP16_Optimizer should be serialized using state_dict().") + + def __setstate__(self, state): + raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().") + + def zero_grad(self, set_grads_to_None=False): + """ + Zero fp32 and fp16 parameter grads. + """ + # In principle, only the .grad attributes of the model params need to be zeroed, + # because gradients are copied into the FP32 master params. However, we zero + # all gradients owned by the optimizer, just to be safe: + for group in self.optimizer.param_groups: + for p in group['params']: + if set_grads_to_None: + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + # Zero fp16 gradients owned by the model: + for fp16_group in self.fp16_groups: + for param in fp16_group: + if set_grads_to_None: + param.grad = None + else: + if param.grad is not None: + param.grad.detach_() # as in torch.optim.optimizer.zero_grad() + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + for group in self.fp32_from_fp32_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + def _master_params_to_model_params(self): + for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp16_group, fp32_from_fp16_group) + + def _model_params_to_master_params(self): + for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp32_from_fp16_group, fp16_group) + + # To consider: Integrate distributed with this wrapper by registering a hook on each variable + # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. + def _model_grads_to_master_grads(self): + for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): + model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) + + def _downscale_master(self): + if self.loss_scale != 1.0: + for group in self.optimizer.param_groups: + for param in group['params']: + if param.grad is not None: + param.grad.data.mul_(1./self.loss_scale) + + def clip_master_grads(self, max_norm, norm_type=2): + """ + Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the current fp32 gradients (viewed as a single vector). + + .. warning:: + Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). + """ + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + return self.clip_grad_norm(fp32_params, max_norm, norm_type) + else: + return -1 + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 2. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']): + for current, saved in zip(current_group, saved_group): + current.data.copy_(saved.data) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, :attr:`step` should be called after + ``fp16_optimizer_obj.backward(loss)``. + :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to + :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params + originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, :attr:`step` may be called without a prior call to + :attr:`backward(loss)`. + This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. + However, the user should take care that any ``loss.backward()`` call within the closure + has been replaced by ``fp16_optimizer_obj.backward(loss)``. + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. + + Example with closure:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + # loss.backward() becomes: + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. warning:: + Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary Pytorch optimizer use`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}" + .format(scale, self.loss_scale)) + return + + if closure is not None: + retval = self._step_with_closure(closure) + else: + retval = self.optimizer.step() + + self._master_params_to_model_params() + + return retval + + def _step_with_closure(self, closure): + def wrapped_closure(): + # helpful for debugging + # print("Calling wrapped_closure, first_closure_call_this_step = {}" + # .format(self.first_closure_call_this_step)) + if self.first_closure_call_this_step: + # We expect that the fp16 params are initially fresh on entering self.step(), + # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() + # is called within self.optimizer.step(). + self.first_closure_call_this_step = False + else: + # If self.optimizer.step() internally calls wrapped_closure more than once, + # it may update the fp32 params after each call. However, self.optimizer + # doesn't know about the fp16 params at all. If the fp32 params get updated, + # we can't rely on self.optimizer to refresh the fp16 params. We need + # to handle that manually: + self._master_params_to_model_params() + # Our API expects the user to give us ownership of the backward() call by + # replacing all calls to loss.backward() with optimizer.backward(loss). + # This requirement holds whether or not the call to backward() is made within a closure. + # If the user is properly calling optimizer.backward(loss) within "closure," + # calling closure() here will give the fp32 master params fresh gradients + # for the optimizer to play with, so all wrapped_closure needs to do is call + # closure() and return the loss. + temp_loss = closure() + while(self.overflow): + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, " + "reducing to {}".format(scale, self.loss_scale)) + temp_loss = closure() + return temp_loss + + retval = self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + return retval + + def backward(self, loss, update_master_grads=True, retain_graph=False): + """ + :attr:`backward` performs the following conceptual steps: + + 1. fp32_loss = loss.float() (see first Note below) + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). + 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. + 5. Finally, master grads are divided by loss_scale. + + In this way, after :attr:`backward`, the master params have fresh gradients, + and :attr:`step` may be called. + + .. note:: + :attr:`backward` internally converts the loss to fp32 before applying the loss scale. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + :attr:`backward`. + + .. warning:: + The gradients found in a model's leaves after the call to + :attr:`backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to :attr:`backward`, + only the master gradients should be regarded as valid. These can be retrieved via + :attr:`inspect_master_grad_data()`. + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. + retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_master_grads=False) + optimizer.backward(loss2, update_master_grads=False) + optimizer.update_master_grads() + """ + # To consider: try multiple backward passes using retain_grad=True to find + # a loss scale that works. After you find a loss scale that works, do a final dummy + # backward pass with retain_graph=False to tear down the graph. Doing this would avoid + # discarding the iteration, but probably wouldn't improve overall efficiency. + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + if update_master_grads: + self.update_master_grads() + + def update_master_grads(self): + """ + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly + updated by the optimizer. :attr:`update_master_grads` only needs to be called if + ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. + """ + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return + self._model_grads_to_master_grads() + self._downscale_master() + + def inspect_master_grad_data(self): + """ + When running with :class:`FP16_Optimizer`, + ``.grad`` attributes of a model's fp16 leaves should not be + regarded as truthful, because they might be scaled. + After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, + the fp32 master params' ``.grad`` + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + nonintuitive. :attr:`inspect_master_grad_data` + allows those gradients to be viewed with shapes corresponding to their associated model leaves. + + Returns: + List of lists (one list for each parameter group). The list for each parameter group + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + """ + if self.overflow: + print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " + "Gradients are currently invalid (may be inf, nan, or stale). Returning None.") + return None + else: + # The optimizer owns only references to master params. + master_grads_data = [] + for param_group in self.optimizer.param_groups: + master_grads_this_group = [] + for param in param_group['params']: + if param.grad is not None: + master_grads_this_group.append(param.grad.data) + else: + master_grads_this_group.append(None) + master_grads_data.append(master_grads_this_group) + return master_grads_data + + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) diff --git a/examples/Megatron-LM/fp16/fp16util.py b/examples/Megatron-LM/fp16/fp16util.py new file mode 100644 index 0000000..8572bb1 --- /dev/null +++ b/examples/Megatron-LM/fp16/fp16util.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from torch.autograd import Variable +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +import mpu + + +class tofp16(nn.Module): + """ + Utility module that implements:: + + def forward(self, input): + return input.half() + """ + + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +def BN_convert_float(module): + """ + Utility function for network_to_half(). + + Retained for legacy purposes. + """ + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: + module.float() + for child in module.children(): + BN_convert_float(child) + return module + + +def network_to_half(network): + """ + Convert model to half precision in a batchnorm-safe way. + + Retained for legacy purposes. It is recommended to use FP16Model. + """ + return nn.Sequential(tofp16(), BN_convert_float(network.half())) + + +def convert_module(module, dtype): + """ + Converts a module's immediate parameters and buffers to dtype. + """ + for param in module.parameters(recurse=False): + if param is not None: + if param.data.dtype.is_floating_point: + param.data = param.data.to(dtype=dtype) + if param._grad is not None and param._grad.data.dtype.is_floating_point: + param._grad.data = param._grad.data.to(dtype=dtype) + + for buf in module.buffers(recurse=False): + if buf is not None and buf.data.dtype.is_floating_point: + buf.data = buf.data.to(dtype=dtype) + + +def convert_network(network, dtype): + """ + Converts a network's parameters and buffers to dtype. + """ + for module in network.modules(): + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: + continue + convert_module(module, dtype) + return network + + +class FP16Model(nn.Module): + """ + Convert model to half precision in a batchnorm-safe way. + """ + + def __init__(self, network): + super(FP16Model, self).__init__() + self.network = convert_network(network, dtype=torch.half) + + def forward(self, *inputs): + inputs = tuple(t.half() for t in inputs) + return self.network(*inputs) + + +def backwards_debug_hook(grad): + raise RuntimeError("master_params recieved a gradient in the backward pass!") + +def prep_param_lists(model, flat_master=False): + """ + Creates a list of FP32 master parameters for a given model, as in + `Training Neural Networks with Mixed Precision: Real Examples`_. + + Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. + + Example:: + + model_params, master_params = prep_param_lists(model) + + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. + + .. _`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ + model_params = [param for param in model.parameters() if param.requires_grad] + + if flat_master: + # Give the user some more useful error messages + try: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors([param.data for param in model_params]).float() + except: + print("Error in prep_param_lists: model may contain a mixture of parameters " + "of different types. Use flat_master=False, or use F16_Optimizer.") + raise + master_params = torch.nn.Parameter(master_params) + master_params.requires_grad = True + # master_params.register_hook(backwards_debug_hook) + if master_params.grad is None: + master_params.grad = master_params.new(*master_params.size()) + return model_params, [master_params] + else: + master_params = [param.clone().float().detach() for param in model_params] + for param in master_params: + param.requires_grad = True + return model_params, master_params + + +def model_grads_to_master_grads(model_params, master_params, flat_master=False): + """ + Copy model gradients to master gradients. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. + """ + if flat_master: + # The flattening may incur one more deep copy than is necessary. + master_params[0].grad.data.copy_( + _flatten_dense_tensors([p.grad.data for p in model_params])) + else: + for model, master in zip(model_params, master_params): + if model.grad is not None: + if master.grad is None: + master.grad = Variable(master.data.new(*master.data.size())) + master.grad.data.copy_(model.grad.data) + else: + master.grad = None + + +def master_params_to_model_params(model_params, master_params, flat_master=False): + """ + Copy master parameters to model parameters. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. + """ + if flat_master: + for model, master in zip(model_params, + _unflatten_dense_tensors(master_params[0].data, model_params)): + model.data.copy_(master) + else: + for model, master in zip(model_params, master_params): + model.data.copy_(master.data) + +# Backward compatibility fixes + +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) + +clip_grad_norm = mpu.clip_grad_norm +#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4: +# clip_grad_norm = torch.nn.utils.clip_grad_norm +#else: +# clip_grad_norm = torch.nn.utils.clip_grad_norm_ diff --git a/examples/Megatron-LM/fp16/loss_scaler.py b/examples/Megatron-LM/fp16/loss_scaler.py new file mode 100755 index 0000000..2c5136e --- /dev/null +++ b/examples/Megatron-LM/fp16/loss_scaler.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import mpu +from deepspeed.accelerator.real_accelerator import get_accelerator + +# item() is a recent addition, so this helps with backward compatibility. +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + +class LossScaler: + """ + Class that manages a static loss scale. This class is intended to interact with + :class:`FP16_Optimizer`, and should not be directly manipulated by the user. + + Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to + :class:`FP16_Optimizer`'s constructor. + + Args: + scale (float, optional, default=1.0): The loss scale. + """ + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss*self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + +class DynamicLossScaler: + """ + Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` + operates, because the default options can be changed using the + the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. + + Loss scaling is designed to combat the problem of underflowing gradients encountered at long + times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are + encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has + occurred. + :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients detected, + :class:`DynamicLossScaler` increases the loss scale once more. + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + always using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. + scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. + """ + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000, + min_scale=1, + delayed_shift=1, + consecutive_hysteresis=False): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + self.min_scale = min_scale + self.delayed_shift = delayed_shift + self.cur_hysteresis = delayed_shift + self.consecutive_hysteresis = consecutive_hysteresis + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params): + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): + return True + + return False + + def has_overflow(self, params): + overflow = self.has_overflow_serial(params) + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + overflow_gpu = get_accelerator().ByteTensor([overflow]) + torch.distributed.all_reduce(overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=mpu.get_model_parallel_group()) + overflow = overflow_gpu[0].item() + return bool(overflow) + + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): + + if not hasattr(self, 'min_scale'): + self.min_scale = 1 + if not hasattr(self, 'delayed_shift'): + self.delayed_shift = 1 + if not hasattr(self, 'cur_hysteresis'): + self.cur_hysteresis = 1 + if not hasattr(self, 'consecutive_hysteresis'): + self.consecutive_hysteresis = True + if overflow: + # self.cur_scale /= self.scale_factor + if self.delayed_shift == 1 or self.cur_hysteresis == 1: + self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale) + else: + self.cur_hysteresis -= 1 + self.last_overflow_iter = self.cur_iter + else: + if self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + if not self.consecutive_hysteresis: + self.cur_hysteresis = self.delayed_shift + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss, retain_graph=False): + scaled_loss = loss*self.loss_scale + scaled_loss.backward(retain_graph=retain_graph) + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +""" +TO-DO separate out into an example. +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 64, 1000, 100, 10 + + # Create random Tensors to hold inputs and outputs, and wrap them in Variables. + x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + + # Run backprop + optimizer.zero_grad() + loss.backward() + + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) + +""" diff --git a/examples/Megatron-LM/generate_samples.py b/examples/Megatron-LM/generate_samples.py new file mode 100755 index 0000000..639f904 --- /dev/null +++ b/examples/Megatron-LM/generate_samples.py @@ -0,0 +1,281 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample Generate GPT2""" + +import os +import random +import numpy as np +import torch +import torch.nn.functional as F +import argparse +import time +from arguments import get_args +from utils import Timers +from pretrain_gpt2 import initialize_distributed +from pretrain_gpt2 import set_random_seed +from pretrain_gpt2 import get_train_val_test_data +from pretrain_gpt2 import get_masks_and_position_ids +from utils import load_checkpoint +from data_utils import make_tokenizer +from configure_data import configure_data +import mpu +from deepspeed.accelerator.real_accelerator import get_accelerator + +from fp16 import FP16_Module +from model import GPT2Model +from model import DistributedDataParallel as DDP +from utils import print_rank_0 + +def get_model(args): + """Build the model.""" + + print_rank_0('building GPT2 model ...') + model = GPT2Model(num_layers=args.num_layers, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + embedding_dropout_prob=args.hidden_dropout, + attention_dropout_prob=args.attention_dropout, + output_dropout_prob=args.hidden_dropout, + max_sequence_length=args.max_position_embeddings, + checkpoint_activations=args.checkpoint_activations, + checkpoint_num_layers=args.checkpoint_num_layers, + parallel_output=False) + + if mpu.get_data_parallel_rank() == 0: + print(' > number of parameters on model parallel rank {}: {}'.format( + mpu.get_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()])), flush=True) + + # GPU allocation. + model.to(torch.device(get_accelerator().current_device_name())) + + # Fp16 conversion. + if args.fp16: + model = FP16_Module(model) + + # Wrap model for distributed training. + model = DDP(model) + + return model + +def setup_model(args): + """Setup model and optimizer.""" + + model = get_model(args) + + if args.load is not None: + _ = load_checkpoint( + model, None, None, args) + + return model + + +def get_batch(context_tokens, device, args): + tokens = context_tokens + tokens = tokens.view(args.batch_size, -1).contiguous() + tokens = tokens.to(device) + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_masks_and_position_ids( + tokens, + args.eod_token, + args.reset_position_ids, + args.reset_attention_mask) + + return tokens, attention_mask, position_ids + +def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + #convert to 1D + logits=logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + #going back to 2D + logits=logits.view(1, -1).contiguous() + + return logits + + +def generate_samples(model, tokenizer, args, device): + + context_count=0 + model.eval() + with torch.no_grad(): + while True: + torch.distributed.barrier(group=mpu.get_model_parallel_group()) + terminate_runs=0 + + if mpu.get_model_parallel_rank() == 0: + raw_text = input("\nContext prompt (stop to exit) >>> ") + while not raw_text: + print('Prompt should not be empty!') + raw_text = input("\nContext prompt (stop to exit) >>> ") + + if "stop" in raw_text: + terminate_runs = 1 + else: + context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization + context_length = len(context_tokens) + + if context_length >=args.seq_length//2: + print("\nContext length", context_length, \ + "\nPlease give smaller context (half of the sequence length)!") + continue + else: + context_tokens = tokenizer.EncodeAsIds("EMPTY TEXT").tokenization + context_length = len(context_tokens) + + terminate_runs_tensor = get_accelerator().LongTensor([terminate_runs]) + torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) + terminate_runs = terminate_runs_tensor[0].item() + + if terminate_runs == 1: + return + + pad_id = tokenizer.get_command('pad').Id + if context_length < args.seq_length: + context_tokens.extend([pad_id] * (args.seq_length - context_length)) + + context_tokens_tensor = get_accelerator().LongTensor(context_tokens) + context_length_tensor = get_accelerator().LongTensor([context_length]) + + torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) + torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group()) + + context_length = context_length_tensor[0].item() + tokens, attention_mask, position_ids=get_batch(context_tokens_tensor, device, args) + + start_time = time.time() + + counter = 0 + org_context_length = context_length + + while counter < (org_context_length + args.out_seq_length): + logits = model(tokens, position_ids, attention_mask) + logits = logits[:, context_length - 1, :] / args.temperature + logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) + log_probs = F.softmax(logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1) + tokens[0, context_length] = prev[0] + context_length += 1 + counter += 1 + + output_tokens_list = tokens.view(-1).contiguous() + decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) + token_end = decode_tokens.find("<|endoftext|>") + + + if mpu.get_model_parallel_rank() == 0 and (counter % 16 == 0 or token_end != -1): + os.system('clear') + print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) + print("\nContext:", raw_text, flush=True) + trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")] + print("\nGPT2:", trim_decode_tokens, flush=True) + if token_end != -1: + break + + if mpu.get_model_parallel_rank() == 0: + os.system('clear') + print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) + print("\nContext:", raw_text, flush=True) + output_tokens_list = tokens.view(-1).contiguous() + decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) + trim_decode_tokens = decode_tokens[len(raw_text):decode_tokens.find("<|endoftext|>")] + print("\nGPT2:", trim_decode_tokens, flush=True) + raw_text = None + + torch.distributed.barrier(group=mpu.get_model_parallel_group()) + context_count += 1 + +def prepare_tokenizer(args): + + tokenizer_args = { + 'tokenizer_type': args.tokenizer_type, + 'corpus': None, + 'model_path': args.tokenizer_path, + 'vocab_size': args.vocab_size, + 'model_type': args.tokenizer_model_type, + 'cache_dir': args.cache_dir} + tokenizer = make_tokenizer(**tokenizer_args) + + args.tokenizer_num_tokens = tokenizer.num_tokens + args.tokenizer_num_type_tokens = tokenizer.num_type_tokens + args.eod_token = tokenizer.get_command('eos').Id + + after = tokenizer.num_tokens + while after % mpu.get_model_parallel_world_size() != 0: + after += 1 + + args.vocab_size = after + print("prepare tokenizer done", flush=True) + + return tokenizer + +def main(): + """Main training program.""" + + print('Generate Samples') + + # Disable CuDNN. + torch.backends.cudnn.enabled = False + + # Timer. + timers = Timers() + + # Arguments. + args = get_args() + + # Pytorch distributed. + initialize_distributed(args) + + # Random seeds for reproducability. + set_random_seed(args.seed) + + #get the tokenizer + tokenizer = prepare_tokenizer(args) + + # Model, optimizer, and learning rate. + model = setup_model(args) + + #setting default batch size to 1 + args.batch_size = 1 + + #generate samples + generate_samples(model, tokenizer, args, get_accelerator().current_device_name()) + + +if __name__ == "__main__": + main() + + + diff --git a/examples/Megatron-LM/gpt2_data_loader.py b/examples/Megatron-LM/gpt2_data_loader.py new file mode 100644 index 0000000..b02927d --- /dev/null +++ b/examples/Megatron-LM/gpt2_data_loader.py @@ -0,0 +1,211 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import numpy as np +import torch +from torch.multiprocessing import Lock +from torch.utils.data import Dataset + +import mpu +from data_utils.samplers import DistributedBatchSampler +from data_utils.tokenization_gpt2 import GPT2Tokenizer + + +def make_gpt2_dataloaders(args): + + # Input parameters. + input_data_sizes_file = args.input_data_sizes_file + seq_length = args.seq_length + initial_seed = args.seed + + # Data parallel arguments. + world_size = mpu.get_data_parallel_world_size() + rank = mpu.get_data_parallel_rank() + global_batch_size = args.batch_size * world_size + num_workers = args.num_workers + + def make_data_loader_(data_path): + # Build the dataset. + dataset = GPT2Dataset(data_path, input_data_sizes_file, + seq_length, initial_seed) + # Use a simple sampler with distributed batch sampler. + sampler = torch.utils.data.SequentialSampler(dataset) + batch_sampler = DistributedBatchSampler(sampler=sampler, + batch_size=global_batch_size, + drop_last=True, + rank=rank, + world_size=world_size) + # Torch dataloader. + return torch.utils.data.DataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + + train = make_data_loader_(args.train_data_path) + valid = make_data_loader_(args.val_data_path) + test = make_data_loader_(args.test_data_path) + + args.do_train = False + args.do_valid = False + args.do_test = False + + if train is not None: + args.do_train = True + if valid is not None: + args.do_valid = True + if test is not None: + args.do_test = True + + # Tokenizer. + tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir) + eod_token = tokenizer.encoder['<|endoftext|>'] + num_tokens = eod_token + 1 + + return (train, valid, test), num_tokens, eod_token + + +class GPT2Dataset(Dataset): + + def __init__(self, data_path, sizes_filename, seq_length, + initial_seed, max_epochs=100): + # Input parameters. + self.data_path = data_path + self.sizes_filename = sizes_filename + self.seq_length = seq_length + self.initial_seed = initial_seed + self.max_epochs = max_epochs + # Lock for building the dataset. + self.lock = Lock() + + # Shard stuff. + # Dictionary from shard nameto its size (number of element). + self.master_shard_size_dict = None + # Dictionary from shard name to modified size so it is + # divisible by self.seq_length. + self.shard_size_dict = None + # Long array (self.max_epochs * num-shards) populated + # randomly with shard names. + self.shards_name = None + # Start index of the data for a shard. + self.shards_start_index = None + self.build_shard_mappings_() + self.data_length = self.shards_start_index[-1] + + # Data. + self.shards_data = [None]*self.shards_name.size + self.shards_sample_index = [None]*self.shards_name.size + + def __len__(self): + return self.data_length + + def __getitem__(self, idx): + # Find which shard we need. + shard_index = np.searchsorted(self.shards_start_index, + idx, side='right') - 1 + # data index in the shard. + data_idx = idx - self.shards_start_index[shard_index] + # Load the shard if it is not in memory. + #self.lock.acquire() + if self.shards_data[shard_index] is None: + print('global rank {} is building data for shard index {} ...'. + format(torch.distributed.get_rank(), shard_index)) + self.build_dataset_(shard_index) + #assert self.shards_data[shard_index] is not None + #self.lock.release() + # Start index. + start_index = self.shards_sample_index[shard_index][data_idx] + # Add one for label shift. + end_index = start_index + self.seq_length + 1 + data = self.shards_data[shard_index][start_index:end_index] + return {'text': np.array(data, dtype=np.int64)} + + def build_dataset_(self, shard_index): + # Garbage collect so we don't use a lot of memory. + # Leave the last one in case other threads have not catche up yet. + #for i in range(shard_index - 1): + for i in range(shard_index): + self.shards_data[i] = None + self.shards_sample_index[i] = None + # Read the shard. + filename = os.path.join(self.data_path, self.shards_name[shard_index]) + print('loading {}'.format(filename)) + data = np.load(filename, allow_pickle=True) + # Shuffle the data + rng = np.random.RandomState(self.initial_seed + shard_index) + rng.shuffle(data) + # Flatten. + data = np.hstack(data) + size = (data.shape[0] - 1) // self.seq_length + last_index = size * self.seq_length + 1 + data = data[0:last_index] + self.shards_data[shard_index] = data + indices = np.arange(size) * self.seq_length + rng.shuffle(indices) + self.shards_sample_index[shard_index] = indices + + def build_shard_mappings_(self): + # Load the sizes file. + sizes_filename = os.path.join(self.data_path, self.sizes_filename) + if torch.distributed.get_rank() == 0: + print(' > loading sizes from {}'.format(sizes_filename)) + with open(sizes_filename, 'r') as f: + self.master_shard_size_dict = json.load(f) + if torch.distributed.get_rank() == 0: + print(' found {} shards'.format(len(self.master_shard_size_dict))) + # Adjust sizes to be a multiple of seq_length. + self.shard_size_dict = self.master_shard_size_dict.copy() + total_samples = 0 + for shard in self.shard_size_dict: + size = self.shard_size_dict[shard] + size = ((size - 1) // self.seq_length) * self.seq_length + total_samples += size // self.seq_length + self.shard_size_dict[shard] = size + if torch.distributed.get_rank() == 0: + print(' found {} samples in the dataset'.format(total_samples)) + # Build a list of shards. + shards_ = np.sort(np.array(list(self.shard_size_dict.keys()))) + rng = np.random.RandomState(self.initial_seed) + self.shards_name = np.copy(shards_) + rng.shuffle(self.shards_name) + for i in range(1, self.max_epochs): + shards_c = np.copy(shards_) + rng.shuffle(shards_c) + self.shards_name = np.append(self.shards_name, shards_c) + # Build the global indexing. + self.shards_start_index = np.zeros(self.shards_name.size, dtype=np.int) + self.shards_start_index[0] = 0 + for i in range(1, self.shards_name.size): + shard = str(self.shards_name[i-1]) + size = self.shard_size_dict[shard] + self.shards_start_index[i] = self.shards_start_index[i-1] + \ + size // self.seq_length + +''' +if __name__ == '__main__': + + print('gpt2 data loader ...') + path = '/raid/mshoeybi/data/gpt2/adlr/reddit_all_ftfy_lg200/npys' + + dataset = GPT2Dataset(path, 'sizes.txt', 1024, 1234, 100) + print('dataset contains {} samples'.format(dataset.data_length)) + + for i in range(len(dataset)): + if i % 512000 == 0: + print(i) + data = dataset[i] +''' diff --git a/examples/Megatron-LM/learning_rates.py b/examples/Megatron-LM/learning_rates.py new file mode 100644 index 0000000..81df8ea --- /dev/null +++ b/examples/Megatron-LM/learning_rates.py @@ -0,0 +1,76 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DataLoader for TFRecords""" + +import torch +from torch.optim.lr_scheduler import _LRScheduler +import math + +class AnnealingLR(_LRScheduler): + """Anneals the learning rate from start to zero along a cosine curve.""" + + DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None'] + + def __init__(self, optimizer, start_lr, warmup_iter, num_iters, decay_style=None, last_iter=-1): + self.optimizer = optimizer + self.start_lr = start_lr + self.warmup_iter = warmup_iter + self.num_iters = last_iter + 1 + self.end_iter = num_iters + self.decay_style = decay_style.lower() if isinstance(decay_style, str) else None + self.step(self.num_iters) + if torch.distributed.get_rank() == 0: + print('learning rate decaying', decay_style) + + def get_lr(self): + # https://openreview.net/pdf?id=BJYwwY9ll pg. 4 + if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: + return float(self.start_lr) * self.num_iters / self.warmup_iter + else: + if self.decay_style == self.DECAY_STYLES[0]: + return self.start_lr*((self.end_iter-(self.num_iters-self.warmup_iter))/self.end_iter) + elif self.decay_style == self.DECAY_STYLES[1]: + return self.start_lr / 2.0 * (math.cos(math.pi * (self.num_iters - self.warmup_iter) / self.end_iter) + 1) + elif self.decay_style == self.DECAY_STYLES[2]: + #TODO: implement exponential decay + return self.start_lr + else: + return self.start_lr + + def step(self, step_num=None): + if step_num is None: + step_num = self.num_iters + 1 + self.num_iters = step_num + new_lr = self.get_lr() + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + def state_dict(self): + sd = { + 'start_lr': self.start_lr, + 'warmup_iter': self.warmup_iter, + 'num_iters': self.num_iters, + 'decay_style': self.decay_style, + 'end_iter': self.end_iter + } + return sd + + def load_state_dict(self, sd): + self.start_lr = sd['start_lr'] + self.warmup_iter = sd['warmup_iter'] + self.num_iters = sd['num_iters'] + self.end_iter = sd['end_iter'] + self.decay_style = sd['decay_style'] + self.step(self.num_iters) diff --git a/examples/Megatron-LM/model/__init__.py b/examples/Megatron-LM/model/__init__.py new file mode 100755 index 0000000..ff2c69e --- /dev/null +++ b/examples/Megatron-LM/model/__init__.py @@ -0,0 +1,20 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .distributed import * +from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization +from .gpt2_modeling import GPT2Model +from .model import BertModel +from .model import get_params_for_weight_decay_optimization diff --git a/examples/Megatron-LM/model/distributed.py b/examples/Megatron-LM/model/distributed.py new file mode 100755 index 0000000..b86733d --- /dev/null +++ b/examples/Megatron-LM/model/distributed.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +import torch.distributed as dist +from torch.nn.modules import Module +from torch.autograd import Variable +from deepspeed.accelerator.real_accelerator import get_accelerator + +import mpu + +class DistributedDataParallel(Module): + + def __init__(self, module): + super(DistributedDataParallel, self).__init__() + self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + self.module = module + self.data_parallel_group = mpu.get_data_parallel_group() + src_rank = mpu.get_model_parallel_rank() + for p in self.module.parameters(): + if torch.is_tensor(p): + dist.broadcast(p, src_rank, group=self.data_parallel_group) + + def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): + if(self.needs_reduction): + self.needs_reduction = False + buckets = {} + for name, param in self.module.named_parameters(): + if param.requires_grad and param.grad is not None: + tp = (param.data.type()) + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(param) + if self.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print("WARNING: gloo dist backend for half parameters may be extremely slow." + + " It is recommended to use the NCCL backend in this case.") + self.warn_on_half = False + for tp in buckets: + bucket = buckets[tp] + grads = [param.grad.data for param in bucket] + coalesced = _flatten_dense_tensors(grads) + if fp32_allreduce: + coalesced = coalesced.float() + if not no_scale and not reduce_after: + coalesced /= dist.get_world_size(group=self.data_parallel_group) + dist.all_reduce(coalesced, group=self.data_parallel_group) + get_accelerator().synchronize() + if not no_scale and reduce_after: + coalesced /= dist.get_world_size(group=self.data_parallel_group) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + self.hook_handles = [] + self.hooks = [] + for param in list(self.module.parameters()): + def allreduce_hook(*unused): + Variable._execution_engine.queue_callback(allreduce_params) + # handle = param.register_hook(allreduce_hook) + #self.hooks.append(allreduce_hook) + #self.hook_handles.append(handle) + self.allreduce_params = allreduce_params + + def forward(self, *inputs, **kwargs): + self.needs_reduction = True + return self.module(*inputs, **kwargs) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + #[h.remove() for h in self.hook_handles] + sd = self.module.state_dict(destination, prefix, keep_vars) + # for handle, hook in zip(self.hook_handles, self.hooks): + # d = handle.hooks_dict_ref() + # d[handle.id] = hook + + return sd + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + + ''' + def _sync_buffers(self): + buffers = list(self.module._all_buffers()) + if len(buffers) > 0: + # cross-node buffer sync + flat_buffers = _flatten_dense_tensors(buffers) + dist.broadcast(flat_buffers, 0) + for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): + buf.copy_(synced) + def train(self, mode=True): + # Clear NCCL communicator and CUDA event cache of the default group ID, + # These cache will be recreated at the later call. This is currently a + # work-around for a potential NCCL deadlock. + if dist._backend == dist.dist_backend.NCCL: + dist._clear_group_cache() + super(DistributedDataParallel, self).train(mode) + self.module.train(mode) + ''' + diff --git a/examples/Megatron-LM/model/gpt2_modeling.py b/examples/Megatron-LM/model/gpt2_modeling.py new file mode 100644 index 0000000..b99fe6a --- /dev/null +++ b/examples/Megatron-LM/model/gpt2_modeling.py @@ -0,0 +1,125 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT-2 model.""" + +import torch +import torch.nn.functional as F + +import mpu + + +def init_method_normal(std=0.02): + """Init method based on normal distribution. + + This is only used for embeddings. The transformer has its + own initializer. + """ + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + return init_ + + +class GPT2Model(torch.nn.Module): + """GPT-2 Language model. + + The output of the forward method are the logits (parallel or + serial depending on the `parallel_output` flag. + """ + + def __init__(self, + num_layers, + vocab_size, + hidden_size, + num_attention_heads, + embedding_dropout_prob, + attention_dropout_prob, + output_dropout_prob, + max_sequence_length, + checkpoint_activations, + checkpoint_num_layers=1, + parallel_output=True): + + super(GPT2Model, self).__init__() + + self.parallel_output = parallel_output + + init_method = init_method_normal(std=0.02) + + # Word embeddings (parallel). + self.word_embeddings = mpu.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init_method) + + # Position embedding (serial). + self.position_embeddings = torch.nn.Embedding(max_sequence_length, + hidden_size) + # Initialize the position embeddings. + init_method(self.position_embeddings.weight) + + # Embeddings dropout + self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) + + # Transformer + self.transformer = mpu.GPT2ParallelTransformer(num_layers, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + checkpoint_activations, + checkpoint_num_layers) + + def forward(self, input_ids, position_ids, attention_mask): + + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + # Dropout. + embeddings = self.embedding_dropout(embeddings) + + # Transformer. + transformer_output = self.transformer(embeddings, attention_mask) + + # Parallel logits. + transformer_output_parallel = mpu.copy_to_model_parallel_region( + transformer_output) + logits_parallel = F.linear(transformer_output_parallel, + self.word_embeddings.weight) + + if self.parallel_output: + return logits_parallel + + return mpu.gather_from_model_parallel_region(logits_parallel) + + +def gpt2_get_params_for_weight_decay_optimization(module): + + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for module_ in module.modules(): + if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): + no_weight_decay_params['params'].extend( + [p for p in list(module_._parameters.values()) + if p is not None]) + else: + weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) + if p is not None and n != 'bias']) + no_weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) + if p is not None and n == 'bias']) + + return weight_decay_params, no_weight_decay_params diff --git a/examples/Megatron-LM/model/model.py b/examples/Megatron-LM/model/model.py new file mode 100755 index 0000000..ea6f205 --- /dev/null +++ b/examples/Megatron-LM/model/model.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for wrapping BertModel.""" + +import torch + +from .modeling import BertConfig +from .modeling import BertForPreTraining, BertForMaskedLM +from .modeling import BertLayerNorm + + +def get_params_for_weight_decay_optimization(module): + + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for module_ in module.modules(): + if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)): + no_weight_decay_params['params'].extend( + [p for p in list(module_._parameters.values()) + if p is not None]) + else: + weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) + if p is not None and n != 'bias']) + no_weight_decay_params['params'].extend( + [p for n, p in list(module_._parameters.items()) + if p is not None and n == 'bias']) + + return weight_decay_params, no_weight_decay_params + + +class BertModel(torch.nn.Module): + + def __init__(self, args): + super(BertModel, self).__init__() + if args.pretrained_bert: + self.model = BertForPreTraining.from_pretrained( + args.tokenizer_model_type, + cache_dir=args.cache_dir, + fp32_layernorm=args.fp32_layernorm, + fp32_embedding=args.fp32_embedding, + layernorm_epsilon=args.layernorm_epsilon) + else: + if args.intermediate_size is None: + intermediate_size = 4 * args.hidden_size + else: + intermediate_size = args.intermediate_size + self.config = BertConfig( + args.tokenizer_num_tokens, + hidden_size=args.hidden_size, + num_hidden_layers=args.num_layers, + num_attention_heads=args.num_attention_heads, + intermediate_size=intermediate_size, + hidden_dropout_prob=args.hidden_dropout, + attention_probs_dropout_prob=args.attention_dropout, + max_position_embeddings=args.max_position_embeddings, + type_vocab_size=args.tokenizer_num_type_tokens, + fp32_layernorm=args.fp32_layernorm, + fp32_embedding=args.fp32_embedding, + fp32_tokentypes=args.fp32_tokentypes, + layernorm_epsilon=args.layernorm_epsilon, + deep_init=args.deep_init) + self.model = BertForPreTraining(self.config) + + def forward(self, input_tokens, token_type_ids=None, + attention_mask=None, checkpoint_activations=False): + return self.model( + input_tokens, token_type_ids, attention_mask, + checkpoint_activations=checkpoint_activations) + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.model.state_dict(destination=destination, prefix=prefix, + keep_vars=keep_vars) + + def load_state_dict(self, state_dict, strict=True): + return self.model.load_state_dict(state_dict, strict=strict) + diff --git a/examples/Megatron-LM/model/modeling.py b/examples/Megatron-LM/model/modeling.py new file mode 100644 index 0000000..d5f8f5a --- /dev/null +++ b/examples/Megatron-LM/model/modeling.py @@ -0,0 +1,1382 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import copy +import json +import math +import logging +import tarfile +import tempfile +import shutil + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +#from torch.utils.checkpoint import checkpoint + +from data_utils.file_utils import cached_path + +import mpu + + +def normal_init_method(mean, std): + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + return init_ + +def scaled_init_method(mean, std, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = std / math.sqrt(2.0 * num_layers) + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=mean, std=std) + + return init_ + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + deep_init=False, + fp32_layernorm=False, + fp32_embedding=False, + fp32_tokentypes=False, + layernorm_epsilon=1e-12): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, str): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.deep_init = deep_init + self.fp32_layernorm = fp32_layernorm + self.fp32_embedding = fp32_embedding + self.layernorm_epsilon = layernorm_epsilon + self.fp32_tokentypes = fp32_tokentypes + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + +try: + from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm +except ImportError: + print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") + class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + #self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, config.hidden_size, + init_method=normal_init_method(mean=0.0, + std=config.initializer_range)) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.fp32_layernorm = config.fp32_layernorm + self.fp32_embedding = config.fp32_embedding + self.fp32_tokentypes = config.fp32_tokentypes + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + if not self.fp32_tokentypes: + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + if self.fp32_embedding and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_embedding: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + else: + embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float() + if self.fp32_tokentypes and not self.fp32_layernorm: + embeddings = embeddings.half() + previous_type = embeddings.type() + if self.fp32_layernorm: + embeddings = embeddings.float() + embeddings = self.LayerNorm(embeddings) + if self.fp32_layernorm: + if self.fp32_tokentypes: + embeddings = embeddings.half() + else: + embeddings = embeddings.type(previous_type) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + previous_type = attention_probs.type() + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method(mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method(mean=0.0, + std=config.initializer_range) + self.dense = mpu.RowParallelLinear( + input_size=config.hidden_size, + output_size=config.hidden_size, + bias=True, + input_is_parallel=True, + stride=1, + init_method=init_method) + self.fp32_layernorm = config.fp32_layernorm + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = mpu.BertParallelSelfAttention( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + dropout_prob=config.attention_probs_dropout_prob, + output_parallel=True, + init_method=normal_init_method(mean=0.0, + std=config.initializer_range)) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = mpu.ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.intermediate_size, + bias=True, + gather_output=False, + stride=1, + init_method=normal_init_method(mean=0.0, + std=config.initializer_range)) + self.intermediate_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + if hasattr(config, 'deep_init') and config.deep_init: + init_method = scaled_init_method(mean=0.0, + std=config.initializer_range, + num_layers=config.num_hidden_layers) + else: + init_method = normal_init_method(mean=0.0, + std=config.initializer_range) + self.dense = mpu.RowParallelLinear( + input_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + input_is_parallel=True, + stride=1, + init_method=init_method) + self.fp32_layernorm = config.fp32_layernorm + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + ln_input = hidden_states + input_tensor + previous_type = ln_input.type() + if self.fp32_layernorm: + ln_input = ln_input.float() + hidden_states = self.LayerNorm(ln_input) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super(BertEncoder, self).__init__() + #layer = BertLayer(config) + #self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + + # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + # all_encoder_layers = [] + # for layer_module in self.layer: + # hidden_states = layer_module(hidden_states, attention_mask) + # if output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # if not output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # return all_encoder_layers + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False): + all_encoder_layers = [] + def custom(start, end): + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer(x_, inputs[1]) + return x_ + return custom_forward + + if checkpoint_activations: + l = 0 + num_layers = len(self.layer) + chunk_length = 1 #math.ceil(math.sqrt(num_layers)) + while l < num_layers: + hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1) + l += chunk_length + # decoder layers + else: + for i,layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attention_mask) + + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + + if not output_all_encoded_layers or checkpoint_activations: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = ACT2FN[config.hidden_act] \ + if isinstance(config.hidden_act, str) else config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + self.fp32_layernorm = config.fp32_layernorm + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + previous_type = hidden_states.type() + if self.fp32_layernorm: + hidden_states = hidden_states.float() + hidden_states = self.LayerNorm(hidden_states) + if self.fp32_layernorm: + hidden_states = hidden_states.type(previous_type) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + #self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + # bert_model_embedding_weights.size(0), + # bias=False) + self.decoder_weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + self.bias.model_parallel = True + self.fp32_embedding = config.fp32_embedding + self.fp32_layernorm = config.fp32_layernorm + def convert_to_type(tensor): + if self.fp32_embedding: + return tensor.half() + else: + return tensor + self.type_converter = convert_to_type + self.converted = False + + def forward(self, hidden_states): + if not self.converted: + self.converted = True + if self.fp32_embedding: + self.transform.half() + if self.fp32_layernorm: + self.transform.LayerNorm.float() + hidden_states = self.transform(self.type_converter(hidden_states)) + # hidden_states = self.decoder(hidden_states) + self.bias + hidden_states = mpu.copy_to_model_parallel_region(hidden_states) + hidden_states = F.linear(self.type_converter(hidden_states), + self.type_converter(self.decoder_weight), + self.type_converter(self.bias)) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + for p in self.seq_relationship.parameters(): + if p is None: + continue + pooled_output = pooled_output.type_as(p) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class PreTrainedBertModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(PreTrainedBertModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, + fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12, + fp32_tokentypes=False, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] + else: + archive_file = pretrained_model_name + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except FileNotFoundError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name, + ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + archive_file)) + return None + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file): + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + config.fp32_layernorm = fp32_layernorm + config.fp32_embedding = fp32_embedding + config.layernorm_epsilon = layernorm_epsilon + config.fp32_tokentypes = fp32_tokentypes + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + return model + + +class BertModel(PreTrainedBertModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLF`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + checkpoint_activations=checkpoint_activations) + sequence_output = encoded_layers[-1] + for p in self.pooler.parameters(): + if p is None: + continue + sequence_output = sequence_output.type_as(p) + break + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers or checkpoint_activations: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class BertForPreTraining(PreTrainedBertModel): + """BERT model with pre-training heads. + This module comprises the BERT model followed by the two pre-training heads: + - the masked language modeling head, and + - the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `masked_lm_labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `masked_lm_labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForPreTraining(config) + masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForPreTraining, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): + sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).float(), masked_lm_labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2).float(), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + return total_loss + else: + return prediction_scores, seq_relationship_score + + +class BertForMaskedLM(PreTrainedBertModel): + """BERT model with the masked language modeling head. + This module comprises the BERT model followed by the masked language modeling head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + + Outputs: + if `masked_lm_labels` is not `None`: + Outputs the masked language modeling loss. + if `masked_lm_labels` is `None`: + Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForMaskedLM(config) + masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForMaskedLM, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + prediction_scores = self.cls(sequence_output) + + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) + return masked_lm_loss + else: + return prediction_scores + + +class BertForNextSentencePrediction(PreTrainedBertModel): + """BERT model with next sentence prediction head. + This module comprises the BERT model followed by the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `next_sentence_label` is not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `next_sentence_label` is `None`: + Outputs the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForNextSentencePrediction(config) + seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForNextSentencePrediction, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, checkpoint_activations=False): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + seq_relationship_score = self.cls( pooled_output) + + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + return next_sentence_loss + else: + return seq_relationship_score + + +class BertForSequenceClassification(PreTrainedBertModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels=2): + super(BertForSequenceClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForMultipleChoice(PreTrainedBertModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_choices = 2 + + model = BertForMultipleChoice(config, num_choices) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_choices=2): + super(BertForMultipleChoice, self).__init__(config) + self.num_choices = num_choices + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, self.num_choices) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + return loss + else: + return reshaped_logits + + +class BertForTokenClassification(PreTrainedBertModel): + """BERT model for token-level classification. + This module is composed of the BERT model with a linear layer on top of + the full hidden state of the last layer. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForTokenClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels=2): + super(BertForTokenClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + #self.classifier = nn.Linear(config.hidden_size, num_labels) + self.classifier = mpu.RowParallelLinear( + input_size=config.hidden_size, + output_size=num_labels, + bias=True, + input_is_parallel=True, + stride=1, + init_method=normal_init_method(mean=0.0, + std=config.initializer_range)) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + with mpu.get_cuda_rng_tracker().fork(): + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForQuestionAnswering(PreTrainedBertModel): + """BERT model for Question Answering (span extraction). + This module is composed of the BERT model with a linear layer on top of + the sequence output that computes start_logits and end_logits + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + + Outputs: + if `start_positions` and `end_positions` are not `None`: + Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. + if `start_positions` or `end_positions` is `None`: + Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end + position tokens of shape [batch_size, sequence_length]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForQuestionAnswering(config) + start_logits, end_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.bert = BertModel(config) + # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + #self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.qa_outputs = mpu.RowParallelLinear( + input_size=config.hidden_size, + output_size=2, + bias=True, + input_is_parallel=True, + stride=1, + init_method=normal_init_method(mean=0.0, + std=config.initializer_range)) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return total_loss + else: + return start_logits, end_logits + + + + diff --git a/examples/Megatron-LM/mpu/__init__.py b/examples/Megatron-LM/mpu/__init__.py new file mode 100755 index 0000000..85cc048 --- /dev/null +++ b/examples/Megatron-LM/mpu/__init__.py @@ -0,0 +1,53 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model parallel utility interface.""" + +from .cross_entropy import vocab_parallel_cross_entropy + +from .data import broadcast_data + +from .grads import clip_grad_norm + +from .initialize import destroy_model_parallel +from .initialize import get_data_parallel_group +from .initialize import get_data_parallel_rank +from .initialize import get_data_parallel_world_size +from .initialize import get_model_parallel_group +from .initialize import get_model_parallel_rank +from .initialize import get_model_parallel_src_rank +from .initialize import get_model_parallel_world_size +from .initialize import initialize_model_parallel +from .initialize import model_parallel_is_initialized + +from .layers import ColumnParallelLinear +from .layers import ParallelEmbedding +from .layers import RowParallelLinear +from .layers import VocabParallelEmbedding + +from .mappings import copy_to_model_parallel_region +from .mappings import gather_from_model_parallel_region +from .mappings import reduce_from_model_parallel_region +from .mappings import scatter_to_model_parallel_region + +from .random import checkpoint +from .random import partition_activations_in_checkpoint +from .random import get_cuda_rng_tracker +from .random import model_parallel_cuda_manual_seed + +from .transformer import BertParallelSelfAttention +from .transformer import BertParallelTransformerLayer +from .transformer import GPT2ParallelTransformer +from .transformer import LayerNorm diff --git a/examples/Megatron-LM/mpu/cross_entropy.py b/examples/Megatron-LM/mpu/cross_entropy.py new file mode 100644 index 0000000..845f044 --- /dev/null +++ b/examples/Megatron-LM/mpu/cross_entropy.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +from .initialize import get_model_parallel_group +from .initialize import get_model_parallel_rank +from .initialize import get_model_parallel_world_size +from .utils import VocabUtility + + +class _VocabParallelCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, vocab_parallel_logits, target): + + # Copy so the input remains unchanged. + logits = vocab_parallel_logits.clone() + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, + op=torch.distributed.ReduceOp.MAX, + group=get_model_parallel_group()) + # Subtract the maximum value. + logits.sub_(logits_max.unsqueeze(dim=-1)) + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = logits.exp() + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group()) + + # Get the partition's vocab indecies + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = get_model_parallel_rank() + world_size = get_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range( + partition_vocab_size, rank, world_size) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], + device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group()) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], + device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= ( + 1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None + + +def vocab_parallel_cross_entropy(vocab_parallel_logits, target): + """Helper function for the cross entropy.""" + return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) diff --git a/examples/Megatron-LM/mpu/data.py b/examples/Megatron-LM/mpu/data.py new file mode 100644 index 0000000..c7d9581 --- /dev/null +++ b/examples/Megatron-LM/mpu/data.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from deepspeed.accelerator.real_accelerator import get_accelerator + +from .initialize import get_model_parallel_group +from .initialize import get_model_parallel_rank +from .initialize import get_model_parallel_src_rank + + +_MAX_DATA_DIM = 4 + + +def _check_data_types(keys, data, target_dtype): + """Check that all the keys have the same target data type.""" + for key in keys: + assert data[key].dtype == target_dtype, '{} has data type {} which '\ + 'is different than {}'.format(key, data[key].dtype, target_dtype) + + +def _build_key_size_numel_dictionaries(keys, data): + """Build the size on rank 0 and broadcast.""" + max_dim = _MAX_DATA_DIM + sizes = [0 for _ in range(max_dim) for _ in keys] + + # Pack the sizes on rank zero. + if get_model_parallel_rank() == 0: + offset = 0 + for key in keys: + assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM' + size = data[key].size() + for i, s in enumerate(size): + sizes[i + offset] = s + offset += max_dim + + # Move to GPU and broadcast. + sizes_cuda = get_accelerator().LongTensor(sizes) + torch.distributed.broadcast(sizes_cuda, get_model_parallel_src_rank(), + group=get_model_parallel_group()) + + # Move back to cpu and unpack. + sizes_cpu = sizes_cuda.cpu() + key_size = {} + key_numel = {} + total_numel = 0 + offset = 0 + for key in keys: + i = 0 + size = [] + numel = 1 + while sizes_cpu[offset + i] > 0: + this_size = sizes_cpu[offset + i] + size.append(this_size) + numel *= this_size + i += 1 + key_size[key] = size + key_numel[key] = numel + total_numel += numel + offset += max_dim + + return key_size, key_numel, total_numel + + +def broadcast_data(keys, data, datatype): + """Broadcast data from rank zero of each model parallel group to the + members of the same model parallel group. + + Arguments: + keys: list of keys in the data disctionary to be broadcasted + data: data dictionary of string keys and cpu tensor values. + datatype: torch data type of all tensors in data associated + with keys. + """ + # Build (key, size) and (key, number of elements) dictionaries along + # with the total number of elements on all ranks. + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, + data) + + # Pack on rank zero. + if get_model_parallel_rank() == 0: + # Check that all keys have the same data type. + _check_data_types(keys, data, datatype) + # Flatten the data associated with the keys + flatten_data = torch.cat( + [data[key].contiguous().view(-1) for key in keys], dim=0).to(get_accelerator().device_name()) + else: + flatten_data = torch.empty(total_numel, + device=get_accelerator().current_device_name(), + dtype=datatype) + + # Boradcast + torch.distributed.broadcast(flatten_data, get_model_parallel_src_rank(), + group=get_model_parallel_group()) + + # Unpack + output = {} + offset = 0 + for key in keys: + size = key_size[key] + numel = key_numel[key] + output[key] = flatten_data.narrow(0, offset, numel).view(size) + offset += numel + + return output diff --git a/examples/Megatron-LM/mpu/grads.py b/examples/Megatron-LM/mpu/grads.py new file mode 100644 index 0000000..c5279e2 --- /dev/null +++ b/examples/Megatron-LM/mpu/grads.py @@ -0,0 +1,75 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + + +import torch +from torch._six import inf +from deepspeed.accelerator.real_accelerator import get_accelerator + +from .initialize import get_model_parallel_group +from .initialize import get_model_parallel_rank + + +def clip_grad_norm(parameters, max_norm, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + max_norm = float(max_norm) + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + # Take max across all GPUs. + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=get_model_parallel_group()) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0 + for p in parameters: + if p.model_parallel or (get_model_parallel_rank() == 0): + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + # Sum across all model parallel GPUs. + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=get_model_parallel_group()) + total_norm = total_norm_cuda[0].item() ** (1. / norm_type) + clip_coef = max_norm / (total_norm + 1e-6) + if clip_coef < 1: + for p in parameters: + p.grad.data.mul_(clip_coef) + return total_norm diff --git a/examples/Megatron-LM/mpu/initialize.py b/examples/Megatron-LM/mpu/initialize.py new file mode 100644 index 0000000..409c939 --- /dev/null +++ b/examples/Megatron-LM/mpu/initialize.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Model and data parallel groups.""" + +import torch + +from .utils import ensure_divisibility + + +# Model parallel group that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + + +def initialize_model_parallel(model_parallel_size_): + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. The present function will + create 4 model parallel groups and 2 data parallel groups as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + if torch.distributed.get_rank() == 0: + print('> initializing model parallel with size {}'.format( + model_parallel_size_)) + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + model_parallel_size = min(model_parallel_size_, world_size) + ensure_divisibility(world_size, model_parallel_size) + rank = torch.distributed.get_rank() + + # Build the data parallel groups. + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, \ + 'data parallel group is already initialized' + for i in range(model_parallel_size): + ranks = range(i, world_size, model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank % model_parallel_size): + _DATA_PARALLEL_GROUP = group + + # Build the model parallel groups. + global _MODEL_PARALLEL_GROUP + assert _MODEL_PARALLEL_GROUP is None, \ + 'model parallel group is already initialized' + for i in range(world_size // model_parallel_size): + ranks = range(i * model_parallel_size, + (i + 1) * model_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // model_parallel_size): + _MODEL_PARALLEL_GROUP = group + + +def model_parallel_is_initialized(): + """Check if model and data parallel groups are initialized.""" + if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: + return False + return True + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + assert _MODEL_PARALLEL_GROUP is not None, \ + 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_model_parallel_world_size(): + """Return world size for the model parallel group.""" + return torch.distributed.get_world_size(group=get_model_parallel_group()) + + +def get_model_parallel_rank(): + """Return my rank for the model parallel group.""" + return torch.distributed.get_rank(group=get_model_parallel_group()) + + +def get_model_parallel_src_rank(): + """Calculate the global rank corresponding to a local rank zero + in the model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return torch.distributed.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return torch.distributed.get_rank(group=get_data_parallel_group()) + + +def destroy_model_parallel(): + """Set the groups to none.""" + global _MODEL_PARALLEL_GROUP + _MODEL_PARALLEL_GROUP = None + global _DATA_PARALLEL_GROUP + _DATA_PARALLEL_GROUP = None diff --git a/examples/Megatron-LM/mpu/layers.py b/examples/Megatron-LM/mpu/layers.py new file mode 100644 index 0000000..85c1220 --- /dev/null +++ b/examples/Megatron-LM/mpu/layers.py @@ -0,0 +1,331 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch + + +import math + +import torch +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.parameter import Parameter + +from deepspeed.accelerator.real_accelerator import get_accelerator +if get_accelerator().device_name() == 'cuda': + from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm +else: + from torch.nn import LayerNorm + +from .initialize import get_model_parallel_rank +from .initialize import get_model_parallel_world_size +from .mappings import copy_to_model_parallel_region +from .mappings import gather_from_model_parallel_region +from .mappings import reduce_from_model_parallel_region +from .mappings import scatter_to_model_parallel_region +from .random import get_cuda_rng_tracker +from .utils import divide +from .utils import split_tensor_along_last_dim +from .utils import VocabUtility + + +def _initialize_affine_weight(weight, output_size, input_size, + per_partition_size, partition_dim, init_method, + stride=1, return_master_weight=False): + """Initialize affine weight for model parallel. + + Build the master weight on all processes and scatter + the relevant chunk.""" + # If we only use 1 process for model parallelism, bypass scatter. + world_size = get_model_parallel_world_size() + if world_size == 1: + init_method(weight) + if return_master_weight: + return weight + return None + + # Initialize master weight + master_weight = torch.empty(output_size, input_size, + dtype=weight.dtype, + requires_grad=False) + init_method(master_weight) + + # Split and copy + per_partition_per_stride_size = divide(per_partition_size, stride) + weight_list = torch.split(master_weight, per_partition_per_stride_size, + dim=partition_dim) + rank = get_model_parallel_rank() + my_weight_list = weight_list[rank::world_size] + + with torch.no_grad(): + torch.cat(my_weight_list, dim=partition_dim, out=weight) + if return_master_weight: + return master_weight + return None + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + def __init__(self, num_embeddings, embedding_dim, + init_method=init.xavier_normal_): + super(VocabParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set the detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = \ + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_model_parallel_rank(), + get_model_parallel_world_size()) + self.num_embeddings_per_partition = self.vocab_end_index - \ + self.vocab_start_index + + # Allocate weights. + self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, + self.embedding_dim)) + self.weight.model_parallel = True + # And initialize. + _initialize_affine_weight( + self.weight, self.num_embeddings, self.embedding_dim, + self.num_embeddings_per_partition, 0, init_method) + + def forward(self, input_): + # Build the mask. + input_mask = (input_ < self.vocab_start_index) | \ + (input_ >= self.vocab_end_index) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + # Mask the output embedding. + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = reduce_from_model_parallel_region(output_parallel) + return output + + +class ParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the embedding dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + init_method: method to initialize weights. + """ + def __init__(self, num_embeddings, embedding_dim, + init_method=init.xavier_normal_, + keep_master_weight_for_test=False): + super(ParallelEmbedding, self).__init__() + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + # Set some detauls for compatibility. + self.padding_idx = None + self.max_norm = None + self.norm_type = 2. + self.scale_grad_by_freq = False + self.sparse = False + self._weight = None + # Divide the weight matrix along the embedding dimension. + world_size = get_model_parallel_world_size() + self.embedding_dim_per_partition = divide(self.embedding_dim, + world_size) + + # Allocate weights. + self.weight = Parameter(torch.Tensor(self.num_embeddings, + self.embedding_dim_per_partition)) + self.weight.model_parallel = True + # And initialize. + _initialize_affine_weight( + self.weight, self.num_embeddings, self.embedding_dim, + self.embedding_dim_per_partition, 1, init_method, + stride=1, return_master_weight=False) + + def forward(self, input_): + input_parallel = copy_to_model_parallel_region(input_) + output_parallel = F.embedding(input_parallel, self.weight, + self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, + self.sparse) + output = gather_from_model_parallel_region(output_parallel) + return output + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias + gather_output: If true, call all-gether on output and make Y avaiable + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + """ + def __init__(self, input_size, output_size, bias=True, gather_output=True, + init_method=init.xavier_normal_, stride=1, + keep_master_weight_for_test=False): + super(ColumnParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + world_size = get_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, world_size) + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.weight = Parameter(torch.Tensor(self.output_size_per_partition, + self.input_size)) + self.weight.model_parallel = True + if bias: + self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) + self.bias.model_parallel = True + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + # Initialize weight. + self.master_weight = _initialize_affine_weight( + self.weight, self.output_size, self.input_size, + self.output_size_per_partition, 0, init_method, + stride=stride, return_master_weight=keep_master_weight_for_test) + + def forward(self, input_): + # Set up backprop all-reduce. + input_parallel = copy_to_model_parallel_region(input_) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight, self.bias) + if self.gather_output: + # All-gather across the partitions. + output = gather_from_model_parallel_region(output_parallel) + else: + output = output_parallel + return output + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + init_method: method to initialize weights. Note that bias is always set + to zero. + stride: For the strided linear layers. + keep_master_weight_for_test: This was added for testing and should be + set to False. It returns the master weights + used for initialization. + """ + def __init__(self, input_size, output_size, bias=True, + input_is_parallel=False, + init_method=init.xavier_normal_, stride=1, + keep_master_weight_for_test=False): + super(RowParallelLinear, self).__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + # Divide the weight matrix along the last dimension. + world_size = get_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, world_size) + + # Parameters. + # Note: torch.nn.functional.linear performs XA^T + b and as a result + # we allocate the transpose. + self.weight = Parameter(torch.Tensor(self.output_size, + self.input_size_per_partition)) + self.weight.model_parallel = True + if bias: + self.bias = Parameter(torch.Tensor(self.output_size)) + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + # Initialize weight. + self.master_weight = _initialize_affine_weight( + self.weight, self.output_size, self.input_size, + self.input_size_per_partition, 1, init_method, + stride=stride, return_master_weight=keep_master_weight_for_test) + + def forward(self, input_): + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + input_parallel = scatter_to_model_parallel_region(input_) + # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight) + # All-reduce across all the partitions. + output_ = reduce_from_model_parallel_region(output_parallel) + if self.bias is not None: + output = output_ + self.bias + else: + output = output_ + return output + diff --git a/examples/Megatron-LM/mpu/mappings.py b/examples/Megatron-LM/mpu/mappings.py new file mode 100644 index 0000000..d91f48b --- /dev/null +++ b/examples/Megatron-LM/mpu/mappings.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from .initialize import get_model_parallel_group +from .utils import split_tensor_along_last_dim + + +def _reduce(input_): + """All-reduce the the input tensor across model parallel group.""" + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # All-reduce. + torch.distributed.all_reduce(input_, group=group) + + return input_ + + +def _split(input_): + """Split the tensor along its last dimension and keep the + corresponding slice.""" + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # Split along last dimension. + world_size = torch.distributed.get_world_size(group=group) + input_list = split_tensor_along_last_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = torch.distributed.get_rank(group=group) + output = input_list[rank].contiguous() + + return output + + +def _gather(input_): + """Gather tensors and concatinate along the last dimension.""" + group = get_model_parallel_group() + + # Bypass the function if we are using only 1 GPU. + if torch.distributed.get_world_size(group=group) == 1: + return input_ + + # Size and dimension. + last_dim = input_.dim() - 1 + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=last_dim).contiguous() + + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-redcue the input from the model parallel region.""" + + @staticmethod + def forward(ctx, input_): + return _reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _ScatterToModelParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def forward(ctx, input_): + return _split(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output) + + +class _GatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def forward(ctx, input_): + return _gather(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output) + + +# ----------------- +# Helper functions. +# ----------------- + +def copy_to_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + +def reduce_from_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + +def scatter_to_model_parallel_region(input_): + return _ScatterToModelParallelRegion.apply(input_) + +def gather_from_model_parallel_region(input_): + return _GatherFromModelParallelRegion.apply(input_) diff --git a/examples/Megatron-LM/mpu/random.py b/examples/Megatron-LM/mpu/random.py new file mode 100755 index 0000000..67be618 --- /dev/null +++ b/examples/Megatron-LM/mpu/random.py @@ -0,0 +1,387 @@ +# coding=utf-8 +#Modified by Samyam Rajbhandari +#Used to partition the activations stored for backward propagation +#Therefore reduces the memory consumption + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +import contextlib +import torch.distributed as dist +import torch +from torch import _C +from deepspeed.accelerator.real_accelerator import get_accelerator +from torch.cuda import _lazy_call, device as device_ctx_manager +#from torch.utils.checkpoint import detach_variable + + +import torch.distributed as dist +PARTITION_ACTIVATIONS = False +PA_CORRECTNESS_TEST= False + +def see_memory_usage(message, force=False): + if not force: + return + dist.barrier() + if dist.get_rank() == 0: + print(message) + print("Memory Allocated ", get_accelerator().memory_allocated()/(1024*1024*1024), "GigaBytes") + print("Max Memory Allocated ", get_accelerator().max_memory_allocated()/(1024*1024*1024), "GigaBytes") + print("Cache Allocated ", get_accelerator().memory_cached()/(1024*1024*1024), "GigaBytes") + print("Max cache Allocated ", get_accelerator().max_memory_cached()/(1024*1024*1024), "GigaBytes") + print(" ") + #input("Press Any Key To Continue ..") + + +from .initialize import get_data_parallel_rank +from .initialize import get_model_parallel_rank +from .initialize import get_model_parallel_world_size +from .initialize import get_model_parallel_group + +mp_rank = None #get_model_parallel_rank() +mp_size = None #get_model_parallel_world_size() +mp_group = None #get_model_parallel_group() + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' +transport_stream = None +cuda_device=None +def detach_variable(inputs, device=None): + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue + + requires_grad = inp.requires_grad + + if device is not None: + x = inp.to(device=device) + else: + x = inp + + x = x.detach() + x.requires_grad = requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Argumentss: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (get_accelerator().set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = get_accelerator().default_generator(idx) + default_generator.set_state(new_state) + + _lazy_call(cb) + + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + states = {} + for name in self.states_: + states[name] = self.states_[name] + return states + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = get_accelerator().get_rng_state() + # Set the new state and store it. + get_accelerator().manual_seed(seed) + self.states_[name] = get_accelerator().get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = get_accelerator().get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = get_accelerator().get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no get_accelerator().manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model paralle groups. This is used for + example for dropout in the non-model-parallel regions. + model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + model_parallel_seed = offset + get_model_parallel_rank() + # Data parallel gets the original sedd. + data_parallel_seed = seed + + if torch.distributed.get_rank() == 0: + print('> initializing model parallel cuda seeds on global rank {}, ' + 'model parallel rank {}, and data parallel rank {} with ' + 'model parallel seed: {} and data parallel seed: {}'.format( + torch.distributed.get_rank(), get_model_parallel_rank(), + get_data_parallel_rank(), model_parallel_seed, + data_parallel_seed), flush=True) + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + get_accelerator().manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, + model_parallel_seed) + + +def get_partition_start(item): + global mp_rank, mp_size, mp_group + partition_size = get_partition_size(item) + start = partition_size * mp_rank + return int(start) + +def get_partition_size(item): + global mp_rank, mp_size, mp_group + size = item.numel() + partition_size = size/mp_size + return int(partition_size) + +def get_full_inputs(tensors): + inputs=[] + for i in range(int(len(tensors)/2)-1): + item = tensors[2 * i] + size = tensors[2* i + 1] + partition_size = item.numel() + tensor_size = partition_size * mp_size + flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=item.device) + partitions=[] + for i in range(mp_size): + part_i = flat_tensor.narrow(0, partition_size * i , partition_size) + if i == mp_rank: + part_i.copy_(item) + partitions.append(part_i) + dist.all_gather(partitions,partitions[mp_rank], group=mp_group) + input_tensor = flat_tensor.view(list(size.numpy())) + item.data=input_tensor.data + + inputs.append(item) + inputs.append(tensors[-2]) + + return tuple(inputs) + + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) get_accelerator().set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + @staticmethod + def forward(ctx, run_function, *args): + ctx.run_function = run_function + global mp_rank, mp_size, mp_group + if mp_rank is None: + mp_rank = get_model_parallel_rank() + mp_size = get_model_parallel_world_size() + mp_group = get_model_parallel_group() + + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS + if cuda_device is None: + if dist.get_rank() == 0: + print(f"Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}") + + cuda_device = get_accelerator().current_device_name() + #The transport stream is used to overlap the allgather communication for the activations + #with the computation in the backward pass + transport_stream = get_accelerator().Stream(device=cuda_device) + + if PARTITION_ACTIVATIONS: + inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]] + inputs.append(args[-1]) + + #just in case something funky is happening such as reuse of inputs + inputs_cuda = [item.to(cuda_device) for item in args] + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = get_accelerator().get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + #ctx.save_for_backward(*args) + with torch.no_grad(): + outputs = run_function(*inputs_cuda) + + del inputs_cuda + + if PARTITION_ACTIVATIONS: + new_args = [] + for arg, inp in zip(args,inputs): + size= torch.tensor(arg.size()) + arg.data = inp.data + new_args.append(arg) + new_args.append(size) + ctx.save_for_backward(*new_args) + else: + ctx.save_for_backward(*args) + + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), " + "please use .backward() if possible") + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS + + if PARTITION_ACTIVATIONS: + with get_accelerator().stream(transport_stream): + inputs = get_full_inputs(ctx.saved_tensors) + detached_inputs = detach_variable(inputs) + else: + inputs = ctx.saved_tensors + detached_inputs = detach_variable(inputs) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = get_accelerator().get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + if PARTITION_ACTIVATIONS: + current_stream=get_accelerator().current_stream() + current_stream.wait_stream(transport_stream) + + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + torch.autograd.backward(outputs, args) + return (None,) + tuple(inp.grad for inp in detached_inputs) + + +def checkpoint(function, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint.""" + return CheckpointFunction.apply(function, *args) + +def partition_activations_in_checkpoint(partition_activation): + global PARTITION_ACTIVATIONS + PARTITION_ACTIVATIONS=partition_activation + if dist.get_rank() == 0: + print(f"**************Partition Activations {PARTITION_ACTIVATIONS}************") + + diff --git a/examples/Megatron-LM/mpu/tests/__init__.py b/examples/Megatron-LM/mpu/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/Megatron-LM/mpu/tests/commons.py b/examples/Megatron-LM/mpu/tests/commons.py new file mode 100644 index 0000000..b1093c5 --- /dev/null +++ b/examples/Megatron-LM/mpu/tests/commons.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import random +import numpy +import torch + +from deepspeed.accelerator.real_accelerator import get_accelerator +import mpu + + +class IdentityLayer(torch.nn.Module): + def __init__(self, size, scale=1.0): + super(IdentityLayer, self).__init__() + self.weight = torch.nn.Parameter(scale * torch.randn(size)) + def forward(self): + return self.weight + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + random.seed(seed) + numpy.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def initialize_distributed(backend='nccl'): + """Initialize torch.distributed.""" + # Get local rank in case it is provided. + parser = argparse.ArgumentParser() + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher') + args = parser.parse_args() + local_rank = args.local_rank + + # Get rank and world size. + rank = int(os.getenv('RANK', '0')) + world_size = int(os.getenv("WORLD_SIZE", '1')) + + print('> initializing torch.distributed with local rank: {}, ' + 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) + + # Set the device id. + device = rank % get_accelerator().device_count() + if local_rank is not None: + device = local_rank + get_accelerator().set_device(device) + + # Call the init process. + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + init_method=init_method) + + +def print_separator(message): + torch.distributed.barrier() + filler_len = (78 - len(message)) // 2 + filler = '-' * filler_len + string = '\n' + filler + ' {} '.format(message) + filler + if torch.distributed.get_rank() == 0: + print(string, flush=True) + torch.distributed.barrier() diff --git a/examples/Megatron-LM/mpu/tests/test_cross_entropy.py b/examples/Megatron-LM/mpu/tests/test_cross_entropy.py new file mode 100644 index 0000000..3987567 --- /dev/null +++ b/examples/Megatron-LM/mpu/tests/test_cross_entropy.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +sys.path.append("../..") + +import torch +import torch.nn.functional as F +from deepspeed.accelerator.real_accelerator import get_accelerator +import mpu +from mpu.cross_entropy import vocab_parallel_cross_entropy + +from commons import initialize_distributed +from commons import print_separator +from commons import IdentityLayer +from commons import set_random_seed + + +def torch_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + target = get_accelerator().LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size) + loss = F.cross_entropy(logits.view(-1, logits.size()[-1]), + target.view(-1), + reduction='none').view_as(target).mean() + loss.backward() + return loss, identity.weight.grad + + +def mpu_cross_entropy(batch_size, seq_length, vocab_size, + logits_scale, seed): + set_random_seed(seed) + identity = IdentityLayer((batch_size, seq_length, vocab_size), + scale=logits_scale).cuda() + logits = identity() + logits_parallel = mpu.scatter_to_model_parallel_region(logits) + target = get_accelerator().LongTensor( + size=(batch_size, seq_length)).random_(0, vocab_size) + loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() + loss.backward() + return loss, identity.weight.grad + + +def test_cross_entropy(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cross entropy with model parallel size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + batch_size = 13 + seq_length = 17 + vocab_size_per_partition = 11 + logits_scale = 1000.0 + vocab_size = vocab_size_per_partition * model_parallel_size + seed = 1234 + + loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, + vocab_size, logits_scale, + seed) + + error = loss_torch.sub_(loss_mpu).abs().max() + print(' max error in loss on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = grad_torch.sub_(grad_mpu).abs().max() + print(' max error in grad on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test cross entropy') + test_cross_entropy(model_parallel_size) + model_parallel_size *= 2 diff --git a/examples/Megatron-LM/mpu/tests/test_data.py b/examples/Megatron-LM/mpu/tests/test_data.py new file mode 100644 index 0000000..6e8eca7 --- /dev/null +++ b/examples/Megatron-LM/mpu/tests/test_data.py @@ -0,0 +1,92 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import operator +import sys +sys.path.append("../..") + +import torch +import mpu +from mpu import data as data_utils + +from commons import initialize_distributed +from commons import print_separator + + +def test_boradcast_data(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing boradcast_data with model parallel size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + torch.manual_seed(1234 + mpu.get_data_parallel_rank()) + model_parallel_size = mpu.get_model_parallel_world_size() + + key_size_t = {'key1': [7, 11], + 'key2': [8, 2, 1], + 'key3': [13], + 'key4': [5, 1, 2], + 'key5': [5, 12]} + keys = list(key_size_t.keys()) + + data = {} + data_t = {} + for key in key_size_t: + data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) + data_t[key] = data[key].clone() + data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) + data_t['keyX'] = data['keyX'].clone() + if mpu.get_model_parallel_rank() != 0: + data = None + + data_utils._check_data_types(keys, data_t, torch.int64) + key_size, key_numel, \ + total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) + for key in keys: + assert key_size[key] == key_size_t[key] + total_numel_t = 0 + for key in keys: + target_size = functools.reduce(operator.mul, key_size_t[key], 1) + assert key_numel[key] == target_size + total_numel_t += target_size + assert total_numel == total_numel_t + + data_b = data_utils.broadcast_data(keys, data, torch.int64) + for key in keys: + tensor = data_t[key].cuda() + assert data_b[key].sub(tensor).abs().max() == 0 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test test boradcast data') + test_boradcast_data(model_parallel_size) + model_parallel_size *= 2 + + diff --git a/examples/Megatron-LM/mpu/tests/test_initialize.py b/examples/Megatron-LM/mpu/tests/test_initialize.py new file mode 100644 index 0000000..c77e2e6 --- /dev/null +++ b/examples/Megatron-LM/mpu/tests/test_initialize.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("../..") + +import torch +import mpu + +from commons import initialize_distributed +from commons import print_separator + + +def test_initialize_model_parallel(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing initialize_model_parallel with size {} ...'.format( + model_parallel_size)) + model_parallel_size_ = min(model_parallel_size, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(model_parallel_size_) + assert mpu.model_parallel_is_initialized() + + # Checks. + def check(group, world_size, rank): + assert world_size == torch.distributed.get_world_size(group=group) + assert rank == torch.distributed.get_rank(group=group) + + # Model parallel. + world_size = model_parallel_size_ + rank = torch.distributed.get_rank() % model_parallel_size_ + assert world_size == mpu.get_model_parallel_world_size() + assert rank == mpu.get_model_parallel_rank() + check(mpu.get_model_parallel_group(), world_size, rank) + + + # Data parallel. + world_size = torch.distributed.get_world_size() // model_parallel_size_ + rank = torch.distributed.get_rank() // model_parallel_size + assert world_size == mpu.get_data_parallel_world_size() + assert rank == mpu.get_data_parallel_rank() + check(mpu.get_data_parallel_group(), world_size, rank) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_get_model_parallel_src_rank(model_parallel_size_): + + if torch.distributed.get_rank() == 0: + print('> testing get_model_parallel_src_rank with size {} ...'.format( + model_parallel_size_)) + model_parallel_size = min(model_parallel_size_, + torch.distributed.get_world_size()) + assert not mpu.model_parallel_is_initialized() + mpu.initialize_model_parallel(model_parallel_size) + assert mpu.model_parallel_is_initialized() + + # Checks + src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank() + assert mpu.get_model_parallel_src_rank() == src_rank + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test initialize model parallel') + test_initialize_model_parallel(model_parallel_size) + print_separator('test model parallel source rank') + test_get_model_parallel_src_rank(model_parallel_size) + model_parallel_size *= 2 diff --git a/examples/Megatron-LM/mpu/tests/test_layers.py b/examples/Megatron-LM/mpu/tests/test_layers.py new file mode 100644 index 0000000..c38bf72 --- /dev/null +++ b/examples/Megatron-LM/mpu/tests/test_layers.py @@ -0,0 +1,529 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import sys +sys.path.append("../..") + +import torch +import torch.nn.init as init +from torch.nn.parameter import Parameter +import mpu + +from commons import initialize_distributed +from commons import print_separator +from commons import set_random_seed +from mpu import layers + + +def test_parallel_embedding(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing parallel embedding with model parallel size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + batch_size = 17 + seq_length = 23 + vocab_size = 48 + hidden_size = 16 + seed = 1236 + + set_random_seed(123) + input_data = torch.LongTensor( + size=(batch_size,seq_length)).random_(0, vocab_size).cuda() + loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() + + set_random_seed(seed) + embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() + + output = embedding_original(input_data) + loss_original = torch.mul(output, loss_weight).sum() + loss_original.backward() + + set_random_seed(seed) + embedding_parallel = layers.ParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_parallel(input_data) + loss_parallel = torch.mul(output, loss_weight).sum() + loss_parallel.backward() + + set_random_seed(seed) + embedding_vocab_parallel = layers.VocabParallelEmbedding( + vocab_size, hidden_size, init_method=init.normal_).cuda() + output = embedding_vocab_parallel(input_data) + loss_vocab_parallel = torch.mul(output, loss_weight).sum() + loss_vocab_parallel.backward() + + torch.distributed.barrier() + error = loss_parallel.sub(loss_original).abs() + print(' error in loss (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + torch.distributed.barrier() + error = loss_vocab_parallel.sub(loss_original).abs() + print(' error in loss (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + hidden_size // model_parallel_size, + 1)[mpu.get_model_parallel_rank()] + error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() + print(' error in grad (parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + weight_grad_orig = torch.split(embedding_original.weight.grad, + vocab_size // model_parallel_size, + 0)[mpu.get_model_parallel_rank()] + error = embedding_vocab_parallel.weight.grad.sub( + weight_grad_orig).abs().max() + print(' error in grad (vocab parallel) on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-12, 'error: {}'.format(error) + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_initialize_affine_weight(model_parallel_size): + + mpu.initialize_model_parallel(model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing initialize_affine_weight with model parallel ' + 'size: {}'.format(model_parallel_size)) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + input_size_coeff = 13 + input_size = input_size_coeff * model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * model_parallel_size + + # --------------- + # Column parallel + # --------------- + weight = torch.empty(output_size_coeff, input_size) + set_random_seed(seed) + layers._initialize_affine_weight(weight, output_size, input_size, + + output_size_coeff, 0, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_model_parallel_rank() + my_weight = torch.split(master_weight, output_size_coeff, + dim=0)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' column parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # ------------ + # Row parallel + # ------------ + weight = torch.empty(output_size, input_size_coeff) + set_random_seed(seed) + mpu.layers._initialize_affine_weight(weight, output_size, input_size, + input_size_coeff, 1, + torch.nn.init.normal_) + # Target. + set_random_seed(seed) + master_weight = torch.empty(output_size, input_size) + torch.nn.init.normal_(master_weight) + rank = mpu.get_model_parallel_rank() + my_weight = torch.split(master_weight, input_size_coeff, + dim=1)[rank].contiguous().clone() + + # Compare. + error = weight.sub(my_weight).abs().max() + torch.distributed.barrier() + print(' row parallel max error (should be zero) on global rank ' + '{}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer2D(torch.nn.Module): + def __init__(self, m , n): + super(IdentityLayer2D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n)) + torch.nn.init.xavier_normal_(self.weight) + def forward(self): + return self.weight + + +def test_column_parallel_linear(model_parallel_size): + + mpu.initialize_model_parallel(model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing ColumnParallelLinear with model parallel ' + 'size: {}'.format(model_parallel_size)) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = mpu.ColumnParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_model_parallel_rank() + my_dLdA = torch.split(dLdA, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + my_dLdb = torch.split(dLdb, output_size_coeff, + dim=0)[rank].contiguous().clone() + error = my_dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +def test_row_parallel_linear(model_parallel_size): + + mpu.initialize_model_parallel(model_parallel_size) + if torch.distributed.get_rank() == 0: + print('> testing RowParallelLinear with model parallel ' + 'size: {}'.format(model_parallel_size)) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + input_size_coeff = 13 + input_size = input_size_coeff * model_parallel_size + output_size_coeff = 17 + output_size = output_size_coeff * model_parallel_size + batch_size = 7 + + # Network + identity_layer = IdentityLayer2D(batch_size, input_size).cuda() + linear_layer = mpu.RowParallelLinear( + input_size, output_size, keep_master_weight_for_test=True).cuda() + loss_weight = torch.randn([batch_size, output_size]).cuda() + # Forward + input_ = identity_layer() + output = linear_layer(input_) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + # Values. + dLdY = loss_weight + X = identity_layer.weight + A = linear_layer.master_weight.cuda() + dLdA = torch.matmul(dLdY.t(), X) + dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) + dLdX = torch.matmul(dLdY, A) + + rank = mpu.get_model_parallel_rank() + my_dLdA = torch.split(dLdA, input_size_coeff, + dim=1)[rank].contiguous().clone() + error = my_dLdA.sub(linear_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdA on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdb.sub(linear_layer.bias.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdb on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + error = dLdX.sub(identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' error in dLdX on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +class IdentityLayer3D(torch.nn.Module): + def __init__(self, m , n, k): + super(IdentityLayer3D, self).__init__() + self.weight = Parameter(torch.Tensor(m, n, k)) + torch.nn.init.xavier_normal_(self.weight) + def forward(self): + return self.weight + + +def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, + sequence_length): + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, + dropout_prob).cuda() + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = attention_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, model_parallel_size, loss, \ + attention_layer, identity_layer + + +def test_parallel_self_attention(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelSelfAttention with model parallel ' + 'size: {}'.format(model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + dropout_prob = 0.0 # has to be zero + batch_size = 5 + sequence_length = 13 + + rank_1, hideen_size_1, model_parallel_size_1, loss_1, \ + attention_layer_1, identity_layer_1 =parallel_self_attention( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + + rank, hidden_size, model_parallel_size, loss, \ + attention_layer, identity_layer =parallel_self_attention( + model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) + assert hideen_size_1 == hidden_size + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + my_lin_grad_list = torch.split( + attention_layer_1.query_key_value.weight.grad, + hidden_size // model_parallel_size, 0)[rank::model_parallel_size] + my_lin_grad = torch.cat(my_lin_grad_list, dim=0) + error = my_lin_grad.sub( + attention_layer.query_key_value.weight.grad).abs().max() + torch.distributed.barrier() + print(' weight gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-6 + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + +def parallel_transformer(model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length): + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed = 12345 + set_random_seed(seed) + + num_att_heads = num_att_heads_per_partition * \ + torch.distributed.get_world_size() + hidden_size = hidden_size_per_att_head * num_att_heads + intermediate_size = 4 * hidden_size + + # Network + identity_layer = IdentityLayer3D(batch_size, sequence_length, + hidden_size).cuda() + transformer_layer = mpu.BertParallelTransformerLayer( + hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, + torch.nn.functional.relu, 1.0e-5).cuda() + + loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda() + attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() + # Forward + input_ = identity_layer() + output = transformer_layer(input_, attention_mask) + loss = torch.mul(output, loss_weight).sum() + # Backward + loss.backward() + + rank = mpu.get_model_parallel_rank() + mpu.destroy_model_parallel() + return rank, hidden_size, model_parallel_size, loss, \ + transformer_layer, identity_layer + + +def test_parallel_transformer_layer(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing ParallelTransformerLayer with model parallel ' + 'size: {}'.format(model_parallel_size)) + + num_att_heads_per_partition = 3 + hidden_size_per_att_head = 7 + batch_size = 5 + sequence_length = 13 + + rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ + transformer_layer_1, identity_layer_1 = parallel_transformer( + 1, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + rank, hidden_size, model_parallel_size, loss, \ + transformer_layer, identity_layer = parallel_transformer( + model_parallel_size, num_att_heads_per_partition, + hidden_size_per_att_head, batch_size, sequence_length) + + error = loss_1.sub(loss).abs().max() + torch.distributed.barrier() + print(' loss error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + error = identity_layer_1.weight.grad.sub( + identity_layer.weight.grad).abs().max() + torch.distributed.barrier() + print(' input gradient error on global rank {}: {}'.format( + torch.distributed.get_rank(), error)) + assert error < 5.0e-5, 'error: {}'.format(error) + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print(' >> passed the test :-)') + + +if __name__ == '__main__': + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + print_separator('test initialize affine weight') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_initialize_affine_weight(model_parallel_size) + model_parallel_size *= 2 + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test parallel embedding') + test_parallel_embedding(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test column-parallel linear') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_column_parallel_linear(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test row-parallel linear') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_row_parallel_linear(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test parallel self-attention') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_parallel_self_attention(model_parallel_size) + model_parallel_size *= 2 + + print_separator('test parallel transformer') + model_parallel_size = 1 + while model_parallel_size <= world_size: + test_parallel_transformer_layer(model_parallel_size) + model_parallel_size *= 2 diff --git a/examples/Megatron-LM/mpu/tests/test_random.py b/examples/Megatron-LM/mpu/tests/test_random.py new file mode 100644 index 0000000..7e719a7 --- /dev/null +++ b/examples/Megatron-LM/mpu/tests/test_random.py @@ -0,0 +1,208 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("../..") + +import torch +from deepspeed.accelerator.real_accelerator import get_accelerator +import mpu + +from commons import initialize_distributed +from commons import print_separator + + +def test_set_cuda_rng_state(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing set_rng_state with size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + size = 123 + seed = 1234 + get_accelerator().manual_seed(1234) + tensor = get_accelerator().FloatTensor(size) + + # Get the state + rng_state = get_accelerator().get_rng_state() + rng_state_copy = rng_state.clone() + + # Do some stuff. + for _ in range(5): + torch.randn(size, out=tensor) + result_1 = tensor.clone() + + assert rng_state.sub(rng_state_copy).max() == 0 + assert get_accelerator().get_rng_state().sub(rng_state_copy).max() > 0 + + # State should be different. + new_rng_state = get_accelerator().get_rng_state() + max_diff = new_rng_state.sub(rng_state).max() + print(' max diff in rng state (should be non-zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), max_diff)) + assert max_diff > 0 + + # Reset the rng state and do the same stuff. + mpu.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + mpu.random._set_cuda_rng_state(rng_state) + for _ in range(5): + torch.randn(size, out=tensor) + result_2 = tensor.clone() + + # Results should be the same + error = result_2.sub(result_1).abs().max() + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Input state should have remained intact. + error = rng_state.sub(rng_state_copy).max() + print(' max error in rng state (should be zero) on global rank {}: {}'. + format(torch.distributed.get_rank(), error)) + assert error == 0 + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_cuda_rng_tracker(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing cuda rng tracker with size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + seed_1 = 1234 + seed_2 = 4321 + size = [12, 21] + tensor = get_accelerator().FloatTensor(size) + + # Set to seed_1 and generate two tensors. + get_accelerator().manual_seed(seed_1) + torch.randn(size, out=tensor) + target_11 = tensor.clone() + torch.randn(size, out=tensor) + target_12 = tensor.clone() + + # Set to seed_2 and generate two tensors. + get_accelerator().manual_seed(seed_2) + torch.randn(size, out=tensor) + target_21 = tensor.clone() + torch.randn(size, out=tensor) + target_22 = tensor.clone() + + # Now if we interleave seed_1 and seed_2, + # we should still get the same tensors + get_accelerator().manual_seed(seed_1) + mpu.get_cuda_rng_tracker().add('test', seed_2) + + torch.randn(size, out=tensor) + result_11 = tensor.clone() + + with mpu.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_21 = tensor.clone() + + torch.randn(size, out=tensor) + result_12 = tensor.clone() + + with mpu.get_cuda_rng_tracker().fork('test'): + torch.randn(size, out=tensor) + result_22 = tensor.clone() + + diff = result_11.sub(result_21).abs().max() + diff = min(diff, result_12.sub(result_22).abs().max()) + print(' max diff in generated tensors (should be non-zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) + assert diff > 1.0e-6 + error = max(result_11.sub(target_11).abs().max(), + result_12.sub(target_12).abs().max()) + error = max(error, result_21.sub(target_21).abs().max()) + error = max(error, result_22.sub(target_22).abs().max()) + print(' max error in generated tensors (should be zero) on ' + 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) + assert error < 1.0e-6 + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +def test_model_parallel_cuda_manual_seed(model_parallel_size): + + if torch.distributed.get_rank() == 0: + print('> testing model parallel cuda manual seed with size {} ...'. + format(model_parallel_size)) + + mpu.initialize_model_parallel(model_parallel_size) + model_parallel_size = mpu.get_model_parallel_world_size() + + mpu.model_parallel_cuda_manual_seed(12345) + assert get_accelerator().initial_seed() == 12345 + with mpu.get_cuda_rng_tracker().fork(): + assert get_accelerator().initial_seed() == (12345 + 2718 + + mpu.get_model_parallel_rank()) + + # Reset the tracker + mpu.get_cuda_rng_tracker().reset() + + # Reset groups + mpu.destroy_model_parallel() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + print('>> passed the test :-)') + + +if __name__ == '__main__': + + initialize_distributed() + world_size = torch.distributed.get_world_size() + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test set rng state') + test_set_cuda_rng_state(model_parallel_size) + model_parallel_size *= 2 + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test cuda rng tracker') + test_cuda_rng_tracker(model_parallel_size) + model_parallel_size *= 2 + + model_parallel_size = 1 + while model_parallel_size <= world_size: + print_separator('test model parallel cuda manual seed') + test_model_parallel_cuda_manual_seed(model_parallel_size) + model_parallel_size *= 2 + diff --git a/examples/Megatron-LM/mpu/transformer.py b/examples/Megatron-LM/mpu/transformer.py new file mode 100755 index 0000000..ce6917b --- /dev/null +++ b/examples/Megatron-LM/mpu/transformer.py @@ -0,0 +1,650 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer.""" + +import math + +import torch +import torch.nn.init as init + +from .initialize import get_model_parallel_world_size +from .layers import ColumnParallelLinear +from .layers import RowParallelLinear +from .mappings import gather_from_model_parallel_region + +import deepspeed +from deepspeed.accelerator.real_accelerator import get_accelerator + +if get_accelerator().device_name() == 'cuda': + from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm +else: + from torch.nn import LayerNorm + +from .random import checkpoint +from .random import get_cuda_rng_tracker + +from .utils import divide +from .utils import split_tensor_along_last_dim + + +class GPT2ParallelSelfAttention(torch.nn.Module): + """Parallel self-attention layer for GPT2. + + Self-attention layer takes input with size [b, s, h] where b is + the batch size, s is the sequence lenght, and h is the hidden size + and creates output of the same size. + Arguments: + hidden_size: total hidden size of the layer (h). + num_attention_heads: number of attention heads (n). Note that we + require n to be divisible by number of GPUs + used to parallelize the model. Also, we + require hidden size to be divisible by n. + dropout_prob: dropout probability for the attention scores. + init_method: weight initialization. + output_layer_init_method: output layer initialization. If None, use + `init_method`. + We use the following notation: + h: hidden_size + n: num_attention_heads + p: number of partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + """ + def __init__(self, hidden_size, num_attention_heads, + attention_dropout_prob, output_dropout_prob, + init_method, output_layer_init_method=None): + super(GPT2ParallelSelfAttention, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, + num_attention_heads) + self.num_attention_heads_per_partition = divide(num_attention_heads, + world_size) + # Strided linear layer. + self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size, + stride=3, + gather_output=False, + init_method=init_method) + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) + + # Output. + self.dense = RowParallelLinear(hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + self.output_dropout = torch.nn.Dropout(output_dropout_prob) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, ltor_mask): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Attention heads. [b, s, hp] + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + # Raw attention scores. [b, np, s, s] + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.hidden_size_per_attention_head) + # Apply the left to right attention mask. + attention_scores = torch.mul(attention_scores, ltor_mask) - \ + 10000.0 * (1.0 - ltor_mask) + + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + + return output + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * + (1.0 + 0.044715 * x * x))) + +def gelu(x): + # TODO: check if it impacts the convergence + # return gelu_impl(x) + return torch.nn.functional.gelu(x) + + +class GPT2ParallelMLP(torch.nn.Module): + """MLP for GPT2. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform gelu transformation, and project the + state back into h hidden dimension. At the end, dropout is also + applied. + + Arguments: + hidden_size: The hidden size of the self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + output_layer_init_method: output layer initialization. If None, + use `init_method`. + """ + + def __init__(self, hidden_size, output_dropout_prob, init_method, + output_layer_init_method=None): + super(GPT2ParallelMLP, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + # Project to 4h. + self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4*hidden_size, + gather_output=False, + init_method=init_method) + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + 4*hidden_size, + hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method) + self.dropout = torch.nn.Dropout(output_dropout_prob) + + def forward(self, hidden_states): + # [b, s, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = gelu(intermediate_parallel) + + # [b, s, h] + output = self.dense_4h_to_h(intermediate_parallel) + output = self.dropout(output) + return output + + +class GPT2ParallelTransformerLayer(torch.nn.Module): + """A single layer transformer for GPT2. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + output_layer_init_method: output layers (attention output and + mlp output) initialization. If None, + use `init_method`. + """ + def __init__(self, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + init_method, + output_layer_init_method=None): + super(GPT2ParallelTransformerLayer, self).__init__() + # Set output layer initialization if not provided. + if output_layer_init_method is None: + output_layer_init_method = init_method + + # Layernorm on the input data. + self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + # Self attention. + self.attention = GPT2ParallelSelfAttention( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method) + + # Layernorm on the input data. + self.post_attention_layernorm = LayerNorm(hidden_size, + eps=layernorm_epsilon) + + # MLP + self.mlp = GPT2ParallelMLP( + hidden_size, + output_dropout_prob, + init_method, + output_layer_init_method=output_layer_init_method) + + def forward(self, hidden_states, ltor_mask): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Layer norm at the begining of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.attention(layernorm_output, ltor_mask) + # Residual connection. + layernorm_input = hidden_states + attention_output + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + # Second residual connection. + output = layernorm_input + mlp_output + + return output + + +def unscaled_init_method(sigma): + """Init method based on N(0, sigma).""" + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + def init_(tensor): + return torch.nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GPT2ParallelTransformer(torch.nn.Module): + """GPT-2 transformer. + + This module takes input from embedding layer and it's output can + be used directly by a logit layer. It consists of L (num-layers) + blocks of: + layer norm + self attention + residual connection + layer norm + mlp + residual connection + followed by a final layer norm. + + Arguments: + num_layers: Number of transformer layers. + hidden_size: The hidden size of the self attention. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + checkpoint_activations: if True, checkpoint activations. + checkpoint_num_layers: number of layers to checkpoint. This + is basically the chunk size in checkpoitning. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method_std: standard deviation of the init method which has + the form N(0, std). + use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers) + scaling for the output weights ( + output of self attention and mlp). + """ + def __init__(self, + num_layers, + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + checkpoint_activations, + checkpoint_num_layers=1, + layernorm_epsilon=1.0e-5, + init_method_std=0.02, + use_scaled_init_for_output_weights=True): + super(GPT2ParallelTransformer, self).__init__() + # Store activation checkpoiting flag. + self.checkpoint_activations = checkpoint_activations + self.checkpoint_num_layers = checkpoint_num_layers + + output_layer_init_method = None + if use_scaled_init_for_output_weights: + output_layer_init_method = scaled_init_method(init_method_std, + num_layers) + def get_layer(): + return GPT2ParallelTransformerLayer( + hidden_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + layernorm_epsilon, + unscaled_init_method(init_method_std), + output_layer_init_method=output_layer_init_method) + + # Transformer layers. + self.layers = torch.nn.ModuleList( + [get_layer() for _ in range(num_layers)]) + + # Final layer norm before output. + self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + + def forward(self, hidden_states, attention_mask): + + def custom(start, end): + def custom_forward(*inputs): + layers_ = self.layers[start:end] + x_ = inputs[0] + for layer in layers_: + x_ = layer(x_, inputs[1]) + return x_ + return custom_forward + + if self.checkpoint_activations: + l = 0 + num_layers = len(self.layers) + chunk_length = self.checkpoint_num_layers + while l < num_layers: + old_mask = attention_mask.clone() + hidden_states = checkpoint(custom(l, l+chunk_length), + hidden_states, attention_mask) + attention_mask = old_mask + l += chunk_length + else: + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + + # Final layer norm. + output = self.final_layernorm(hidden_states) + + return output + + +class BertParallelSelfAttention(torch.nn.Module): + """Parallel self-attention layer for BERT. + + Self-attention layer takes input with size [b, s, h] where b is + the batch size, s is the sequence lenght, and h is the hidden size + and creates output of the same size. + Arguments: + hidden_size: total hidden size of the layer (h). + num_attention_heads: number of attention heads (n). Note that we + require n to be divisible by number of GPUs + used to parallelize the model. Also, we + require hidden size be divisible by n. + dropout_prob: dropout probability for the attention scores. + output_parallel: If true, no all-gather is done on the output and + the output values will be per partition. + We use the following notation: + h: hidden_size + n: num_attention_heads + p: number of partitions + np: n/p + hp: h/p + hn: h/n + b: batch size + s: sequence length + """ + def __init__(self, hidden_size, num_attention_heads, + dropout_prob, output_parallel=False, + init_method=init.xavier_normal_): + super(BertParallelSelfAttention, self).__init__() + # Input configuration. + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.dropout_prob = dropout_prob + self.output_parallel = output_parallel + # Per attention head and per partition values. + world_size = get_model_parallel_world_size() + self.hidden_size_per_partition = divide(hidden_size, world_size) + self.hidden_size_per_attention_head = divide(hidden_size, + num_attention_heads) + self.num_attention_heads_per_partition = divide(num_attention_heads, + world_size) + # Strided linear layer. + self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size, + stride=3, + gather_output=False, + init_method=init_method) + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.dropout = torch.nn.Dropout(dropout_prob) + + if deepspeed.checkpointing.is_configured(): + global get_cuda_rng_tracker, checkpoint + get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + checkpoint = deepspeed.checkpointing.checkpoint + + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + + # Attention heads. [b, s, hp] + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, + mixed_key_layer, + mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + # Raw attention scores. [b, np, s, s] + norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head)) + attention_scores = torch.matmul(query_layer/norm_factor, + key_layer.transpose(-1, -2)/norm_factor) + # Apply the attention mask. + attention_scores += attention_mask + + # Attention probabilities. [b, np, s, s] + attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + with get_cuda_rng_tracker().fork(): + attention_probs = self.dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + if self.output_parallel: + output = context_layer + else: + output = gather_from_model_parallel_region(context_layer) + + return output + + +class BertParallelTransformerOutput(torch.nn.Module): + """The output layer used after self attention and intermediate + parts of transformer layer.""" + def __init__(self, input_size, output_size, dropout_prob, + layernorm_epsilon=1.0e-12, input_is_parallel=False, + init_method=init.xavier_normal_): + super(BertParallelTransformerOutput, self).__init__() + # Components. + self.dense = RowParallelLinear(input_size, + output_size, + input_is_parallel=input_is_parallel, + init_method=init_method) + self.dropout = torch.nn.Dropout(dropout_prob) + self.layernorm = LayerNorm(output_size, eps=layernorm_epsilon) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + layernorm_input = hidden_states + input_tensor + hidden_states = self.layernorm(layernorm_input) + return hidden_states + + +class BertParallelTransformerLayer(torch.nn.Module): + """A single layer transformer for Bert. + + We use the following notation: + h: hidden size + n: number of attention heads + b: batch size + s: sequence length + Transformore layer takes input with size [b, s, h] and returns an + output of the same size. + + Arguments: + hidden_size: The hidden size of the self attention. + intermediate_size: size of the intermediate state after + self attention. In both BERT and GPT + this is set to be 4 times the hidden + size. + num_attention_heads: number of attention head in the self + attention. + attention_dropout_prob: dropout probability of the attention + score in self attention. + output_dropout_prob: dropout probability for the outputs + after self attention and final output. + intermediate_activation_fn: activation function for output + of intermediate. + layernorm_epsilon: epsilon used in layernorm to avoid + division by zero. + init_method: initialization method used for the weights. Note + that all biases are initialized to zero and + layernorm weight are initialized to one. + """ + def __init__(self, + hidden_size, + intermediate_size, + num_attention_heads, + attention_dropout_prob, + output_dropout_prob, + intermediate_activation_fn, + layernorm_epsilon, + init_method=init.xavier_normal_): + super(BertParallelTransformerLayer, self).__init__() + + # Self attention. + self.attention = BertParallelSelfAttention(hidden_size, + num_attention_heads, + attention_dropout_prob, + output_parallel=True, + init_method=init_method) + # Self attention output. + self.self_output = BertParallelTransformerOutput( + hidden_size, hidden_size, output_dropout_prob, + layernorm_epsilon=layernorm_epsilon, + input_is_parallel=True, + init_method=init_method) + # Intermediate. + self.intermediate = ColumnParallelLinear(hidden_size, intermediate_size, + gather_output=False, + init_method=init_method) + self.intermediate_activation_fn = intermediate_activation_fn + # Output. + self.output = BertParallelTransformerOutput( + intermediate_size, hidden_size, output_dropout_prob, + layernorm_epsilon=layernorm_epsilon, + input_is_parallel=True, + init_method=init_method) + + def forward(self, hidden_states, attention_mask): + # [b, s, hp] + attention_output_parallel = self.attention(hidden_states, + attention_mask) + # [b, s, h] + attention_self_output = self.self_output(attention_output_parallel, + hidden_states) + # [b, s, ip] + intermediate_output_parallel = self.intermediate(attention_self_output) + intermediate_output_parallel = self.intermediate_activation_fn( + intermediate_output_parallel) + # [b, s, h] + layer_output = self.output(intermediate_output_parallel, + attention_self_output) + + return layer_output diff --git a/examples/Megatron-LM/mpu/utils.py b/examples/Megatron-LM/mpu/utils.py new file mode 100644 index 0000000..94afafd --- /dev/null +++ b/examples/Megatron-LM/mpu/utils.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, '{} is not divisible by {}'.format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim(tensor, num_partitions, + contiguous_split_chunks=False): + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class VocabUtility: + """Split the vocabulary into `world_size` chunks amd return the + first and last index of the vocabulary belonging to the `rank` + partition: Note that indecies in [fist, last)""" + + @staticmethod + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, world_size): + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + @staticmethod + def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): + per_partition_vocab_size = divide(global_vocab_size, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, world_size) diff --git a/examples/Megatron-LM/openwebtext/README.md b/examples/Megatron-LM/openwebtext/README.md new file mode 100644 index 0000000..db55e10 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/README.md @@ -0,0 +1,46 @@ +The following steps show how to prepare training dataset to train the mode. + +# Libraries to install + +``` + pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract + git clone https://github.com/mattilyra/LSH + cd LSH + python setup.py install +``` + +# Download the dataset + +1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ) +2. Remove blacklisted URLs. +``` +python blacklist_urls.py +``` +3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py). + +4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique. + +# Prepare the data for GPT-2 training: + +1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards. +``` +python cleanup_dataset.py +``` +2. Using LSH, find possible duplicates and store then in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for OpenWebText dataset. +``` +python find_duplicates.py +``` +3. Based on similarity measure defind inside function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, only one url we should keep and remove the rest. +``` +python group_duplicate_urls.py +``` +4. Remove similar documents that were detected in the last step. +``` +python remove_group_duplicates.py +``` + +5. Shuffle the dataset. +``` +shuf -o train_data.json +``` + diff --git a/examples/Megatron-LM/openwebtext/blacklist_urls.py b/examples/Megatron-LM/openwebtext/blacklist_urls.py new file mode 100644 index 0000000..89fa59c --- /dev/null +++ b/examples/Megatron-LM/openwebtext/blacklist_urls.py @@ -0,0 +1,312 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import glob +import re +import time +import tldextract +import sys + + +# List of the domains to blacklist. +domain_blacklist = set([ + '500px', + 'aapks', + 'akamaihd', + 'amazon', + 'apple', + 'artifactfire', + 'artstation', + 'awwni', + 'bandcamp', + 'battleforthenet', + 'coinscalendar', + 'dailymotion', + 'deviantart', + 'discord', + 'discordapp', + 'dlapkandroid', + 'dropbox', + 'e621', + 'ebay', + 'edealinfo', + 'erome', + 'eroshare', + 'explosm', + 'facebook', + 'fbcdn', + 'flickr', + 'furaffinity', + 'futhead', + 'gatopardo', + 'gfycat', + 'gifsound', + 'gifsoup', + 'giphy', + 'github', + 'google', + 'gunprime', + 'gyazo', + 'hotdealstar', + 'imagefap', + 'imageshack', + 'imgflip', + 'imgur', + 'instagram', + 'karmadecay', + 'kryptocal', + 'kym-cdn', + 'liveleak', + 'livememe', + 'lmgtfy', + 'magaimg', + 'memegenerator', + 'minorplanetcenter', + 'minus', + 'mobafire', + 'morejpeg', + 'nocookie', + 'pcpartpicker', + 'photobucket', + 'pinimg', + 'pinterest', + 'pixiv', + 'pornhub', + 'prntscr', + 'puu', + 'qkme', + 'quickmeme', + 'radd', + 'redd', + 'reddit', + 'reddit-stream', + 'redditlog', + 'redditmedia', + 'reddituploads', + 'redtube', + 'reupp', + 'reverb', + 'roanoke', + 'rollingstone', + 'sli', + 'soundcloud', + 'soundgasm', + 'spankbang', + 'spotify', + 'strawpoll', + 'streamable', + 'timeanddate', + 'tinypic', + 'touhouradio', + 'tumblr', + 'twimg', + 'twitch', + 'twitter', + 'vid', + 'vimeo', + 'vine', + 'vkaao', + 'vocaroo', + 'voyagefusion', + 'walmart', + 'wciu', + 'wikimedia', + 'wikipedia', + 'xhamster', + 'xkcd', + 'xvideos', + 'youtu', + 'youtube', + 'youtubedoubler', + 'ytimg', + 'zillexplorer', +]) + +def domain_is_in_blacklist(url): + domain = tldextract.extract(url).domain + return domain in domain_blacklist + + +# List of extentions to blacklist. +extentions_blacklist = ( + '.3gp', + '.7z' + '.ai', + '.aif', + '.apk', + '.app', + '.avi', + '.bin', + '.bmp', + '.bz2', + '.css', + '.csv', + '.dat', + '.deb', + '.dmg', + '.doc', + '.docx', + '.exe', + '.gif', + '.gifv', + '.gz', + '.iso', + '.jar', + '.jpeg', + '.jpg', + '.js', + '.log', + '.mid', + '.midi', + '.mkv', + '.mov', + '.mp3', + '.mp4', + '.mpeg', + '.mpg', + '.ogg', + '.ogv', + '.otf', + '.pdf', + '.pkg', + '.png', + '.pps', + '.ppt', + '.pptx', + '.psd', + '.py', + '.qt', + '.ram', + '.rar', + '.sql', + '.svg', + '.swf', + '.tar.gz', + '.tar', + '.tgz', + '.tiff', + '.ttf', + '.txt', + '.wav', + '.webm', + '.wma', + '.wmv', + '.xls', + '.xlsx', + '.xml', + '.xz', + '.zip', +) + +def extention_is_in_blacklist(url): + if url.split('?')[0].lower().endswith(extentions_blacklist): + return True + return False + + +# Malformed urls. +# This function is adapted from: +# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not +url_regex = re.compile( + r'^(?:http)s?://' # http:// or https:// + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain... + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip + r'(?::\d+)?' # optional port + r'(?:/?|[/?]\S+)$', re.IGNORECASE) +def url_is_malformed(url): + return re.match(url_regex, url) is None + + +def print_progress(prefix, start_time, urls_counter, + domain_blacklist_counter, + extention_blacklist_counter, + short_url_counter, malformed_url_counter, + duplicate_url_counter): + string = prefix + ' | ' + string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time) + string += 'number of urls: {} | '.format(urls_counter) + string += 'domain blacklisted: {} | '.format(domain_blacklist_counter) + string += 'extention blacklisted: {} | '.format(extention_blacklist_counter) + string += 'short urls (<=8): {} | '.format(short_url_counter) + string += 'malformed urls: {} | '.format(malformed_url_counter) + string += 'duplicate urls: {}'.format(duplicate_url_counter) + print(string, flush=True) + + +if __name__ == '__main__': + + + print('remove blacklisted urls ..') + + # Path to the url files. + path = sys.argv[1] + # Output url file. + output = sys.argv[2] + + # Get the list of url files. + files = glob.glob(path + '/*.txt') + print('> found {} files'.format(len(files))) + + urls = set() + urls_counter = 0 + domain_blacklist_counter = 0 + extention_blacklist_counter = 0 + short_url_counter = 0 + malformed_url_counter = 0 + duplicate_url_counter = 0 + start_time = time.time() + for filename in files: + with open(filename, 'r') as f: + for line in f: + url = line.strip() + urls_counter += 1 + if domain_is_in_blacklist(url): + print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True) + domain_blacklist_counter += 1 + elif extention_is_in_blacklist(url): + print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True) + extention_blacklist_counter += 1 + elif len(url) <= 8: + print('[SHORT URL]: {}'.format(url), flush=True) + short_url_counter += 1 + elif url_is_malformed(url): + print('[MALFORMED URL]: {}'.format(url), flush=True) + malformed_url_counter += 1 + elif url in urls: + print('[DUPLICATE URL]: {}'.format(url), flush=True) + duplicate_url_counter += 1 + else: + urls.add(url) + if urls_counter % 100000 == 0: + print_progress('PROGRESS', start_time, urls_counter, + domain_blacklist_counter, + extention_blacklist_counter, + short_url_counter, malformed_url_counter, + duplicate_url_counter) + + print_progress('FINAL', start_time, urls_counter, + domain_blacklist_counter, + extention_blacklist_counter, + short_url_counter, malformed_url_counter, + duplicate_url_counter) + + # Write the final set of urls. + print('> writing cleaned up url list to {}'.format(output)) + with open(output, 'w') as f: + for url in urls: + f.write(url + '\n') + + print('done :-)') diff --git a/examples/Megatron-LM/openwebtext/cleanup_dataset.py b/examples/Megatron-LM/openwebtext/cleanup_dataset.py new file mode 100644 index 0000000..ea418b8 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/cleanup_dataset.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import ftfy +import json +from langdetect import detect +import numpy as np +import time +import os +import sys + +from tokenizer import Tokenizer + +MIN_DOCUMENT_LENGHT = 128 + + +def print_progress(prefix, start_time, num_docs, num_fixed_text, + num_non_english_docs, chars_non_english_docs, + num_small_docs, chars_small_docs): + + string = prefix + ' | ' + string += 'elapsed time: {:.2f} | '.format(time.time() - start_time) + string += 'documents: {} | '.format(num_docs) + string += 'fixed text: {} | '.format(num_fixed_text) + string += 'non-english: {} | '.format(num_non_english_docs) + string += 'non-english chars: {} | '.format(chars_non_english_docs) + string += 'small docs: {} | '.format(num_small_docs) + string += 'small docs chars: {}'.format(chars_small_docs) + print(string, flush=True) + + +def filter_corpus(filename, out_filename, print_interval=10000): + + print(' > filtering {}'.format(filename)) + + tokenizer = Tokenizer(cache_dir='./cache') + + num_docs = 0 + num_written_docs = 0 + num_small_docs = 0 + num_fixed_text = 0 + num_non_english_docs = 0 + chars_non_english_docs = 0 + chars_small_docs = 0 + start_time = time.time() + with open(out_filename, 'wb') as f: + with open(filename, 'r') as fin: + for line in fin: + try: + num_docs += 1 + myjson = json.loads(line) + # Fix text + text = ftfy.fix_text(myjson['text']) + if text != myjson['text']: + num_fixed_text += 1 + myjson['text'] = text + # Detect language. + if detect(text) != 'en': + print('[non-english text]', myjson) + num_non_english_docs += 1 + chars_non_english_docs += len(text) + continue + # On average each token is 5 characters so 8 is an + # upper bound. + if len(text) < (8 * MIN_DOCUMENT_LENGHT): + tokens = tokenizer.tokenize_document(text) + if len(tokens) < MIN_DOCUMENT_LENGHT: + print('[small document, skipping]:', myjson) + num_small_docs += 1 + chars_small_docs += len(text) + continue + myjson = json.dumps(myjson, ensure_ascii=False) + f.write(myjson.encode('utf-8')) + f.write('\n'.encode('utf-8')) + num_written_docs += 1 + if num_docs % print_interval == 0: + print_progress('[PROGRESS]', start_time, num_docs, + num_fixed_text, num_non_english_docs, + chars_non_english_docs, + num_small_docs, chars_small_docs) + except Exception as e: + print(' skipping ', line, e) + + print_progress('[FINAL]', start_time, num_docs, + num_fixed_text, num_non_english_docs, + chars_non_english_docs, + num_small_docs, chars_small_docs) + + +if __name__ == '__main__': + + print('building gpt2 dataset ...') + + input_filename = sys.argv[1] + output_filename = sys.argv[2] + + print('will be reading {}'.format(input_filename)) + print('and will write the results to {}'.format(output_filename)) + + filter_corpus(input_filename, output_filename) + + diff --git a/examples/Megatron-LM/openwebtext/find_duplicates.py b/examples/Megatron-LM/openwebtext/find_duplicates.py new file mode 100644 index 0000000..6e20c08 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/find_duplicates.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import itertools +import json +from lsh import cache, minhash +import time +import sys + + +# This function is adapted from: +# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb +def shingles(text, char_ngram=5): + return set(text[head:head + char_ngram] + for head in range(0, len(text) - char_ngram)) + + +# This function is adapted from: +# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb +def jaccard(set_a, set_b): + intersection = set_a & set_b + union = set_a | set_b + return len(intersection) / len(union) + + +if __name__ == '__main__': + + print('finding possible duplicate content ...') + + input = sys.argv[1] + output = sys.argv[2] + + hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4) + lshcache = cache.Cache(bands=10, hasher=hasher) + + counter = 0 + url_doc = {} + start_time = time.time() + with open(input, 'r') as f: + for line in f: + try: + myjson = json.loads(line) + url = myjson['url'] + text = myjson['text'] + counter += 1 + url_doc[url] = text + lshcache.add_fingerprint(hasher.fingerprint(text), url) + except Exception as e: + print('Error:', e) + if counter % 10000 == 0: + print(' [read]> processed {} documents in {:.2f} seconds ...'. + format(counter, time.time() - start_time), flush=True) + + counter = 0 + start_time = time.time() + deduped = 0 + with open(output, 'wb') as f: + for b in lshcache.bins: + for bucket_id in b: + if len(b[bucket_id]) > 1: + items = list(b[bucket_id]) + main_url = items[0] + main_dhingles = shingles(url_doc[main_url]) + remove_urls = [] + for i in range(1, len(items)): + counter += 1 + other_url= items[i] + other_shingles = shingles(url_doc[other_url]) + try: + jaccard_sim = jaccard(main_dhingles, other_shingles) + except Exception as e: + print('Error:', e) + if jaccard_sim > 0.5: + remove_urls.append({other_url: jaccard_sim}) + deduped += 1 + if counter % 10000 == 0: + print(' [write]> processed {} documents in {:.2f} ' + 'seoncds and deduped {} documents ...'. + format(counter, time.time() - start_time, + deduped), flush=True) + if len(remove_urls) > 0: + myjson = json.dumps({main_url: remove_urls}, + ensure_ascii=False) + f.write(myjson.encode('utf-8')) + f.write('\n'.encode('utf-8')) + + print('done :-)') diff --git a/examples/Megatron-LM/openwebtext/group_duplicates_url.py b/examples/Megatron-LM/openwebtext/group_duplicates_url.py new file mode 100644 index 0000000..0381f47 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/group_duplicates_url.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import time +import sys + + +def is_similar(jaccard_similarity): + return (js >= 0.9) + + +if __name__ == '__main__': + + + print('grouping duplicate urls ...') + + input = sys.argv[1] + output = sys.argv[2] + + url_to_index = {} + index_to_urls = [] + counter = 0 + start_time = time.time() + with open(input, 'r') as f: + for line in f: + counter += 1 + myjson = json.loads(line) + urls = [] + for main_url in myjson.keys(): + urls.append(main_url) + for value in myjson[main_url]: + for other_url, js in value.items(): + if is_similar(js): + urls.append(other_url) + current_index = -1 + other_indices = set() + for url in urls: + if url in url_to_index: + if current_index == -1: + current_index = url_to_index[url] + elif current_index != url_to_index[url]: + other_indices.add(url_to_index[url]) + if current_index == -1: + current_index = len(index_to_urls) + index_to_urls.append(set()) + for url in urls: + url_to_index[url] = current_index + index_to_urls[current_index].add(url) + for index in other_indices: + for url in index_to_urls[index]: + index_to_urls[current_index].add(url) + url_to_index[url] = current_index + index_to_urls[index] = None + + if counter % 100000 == 0: + print(' > processed {} lines in {} seconds ...'.format( + counter, time.time() - start_time)) + + + total_remove = 0 + total_remain = 0 + for urls in index_to_urls: + if urls is not None: + if len(urls) > 1: + total_remove += (len(urls) - 1) + total_remain += 1 + print('out of {} urls, only {} are unique and {} should be removed'.format( + total_remove+total_remain, total_remain, total_remove)) + + with open(output, 'wb') as f: + for i, urls in enumerate(index_to_urls): + if urls is not None: + if len(urls) > 1: + myjson = json.dumps({str(i): list(urls)}, + ensure_ascii=False) + f.write(myjson.encode('utf-8')) + f.write('\n'.encode('utf-8')) diff --git a/examples/Megatron-LM/openwebtext/make_gpt2_dataset.py b/examples/Megatron-LM/openwebtext/make_gpt2_dataset.py new file mode 100644 index 0000000..48b57e8 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/make_gpt2_dataset.py @@ -0,0 +1,77 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import numpy as np +import time +import os +import sys + +from tokenizer import Tokenizer + + +def tokenize_corpus(filename, np_filename, print_interval=10000): + + print(' > tokenizing {}'.format(filename)) + + tokenizer = Tokenizer(cache_dir='./cache') + + tokenized_docs = [] + num_docs = 0 + num_tokens = 0 + start_time = time.time() + with open(filename, 'r') as f: + for line in f: + try: + myjson = json.loads(line) + url = myjson['url'] + sample = myjson['text'] + tokens = tokenizer.tokenize_document(sample) + tokenized_docs.append(np.array(tokens, dtype=np.uint16)) + num_docs += 1 + num_tokens += len(tokens) + if num_docs % print_interval == 0: + print(' processed {:9d} documents in {:.2f} (s) so far'. + format(num_docs, time.time() - start_time), + flush=True) + except Exception as e: + print(' skipping ', line, e) + + print(' >> processed {} document with total of {} tokens ...'.format( + num_docs, num_tokens)) + + tokenized_docs = np.array(tokenized_docs, dtype=object) + np.save(np_filename, tokenized_docs, allow_pickle=True) + print(' >> saved the tokenzed document to {} ...'.format(np_filename)) + + +if __name__ == '__main__': + + print('building gpt2 dataset ...') + + path = sys.argv[1] + shard = sys.argv[2] + + input_filename = os.path.join(path, + 'shards/shard_{:04d}'.format(int(shard))) + output_filename = os.path.join(path, + 'npys/shard_{:04d}.npy'.format(int(shard))) + print('will be reading {}'.format(input_filename)) + print('and will write the results to {}'.format(output_filename)) + + tokenize_corpus(input_filename, output_filename) + + diff --git a/examples/Megatron-LM/openwebtext/make_gpt2_sizes.py b/examples/Megatron-LM/openwebtext/make_gpt2_sizes.py new file mode 100644 index 0000000..9d77749 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/make_gpt2_sizes.py @@ -0,0 +1,38 @@ + +import glob +import json +import os +import time +import sys + +import numpy as np + + +if __name__ == '__main__': + + print('building the shard sizes ...') + + path = sys.argv[1] + print('> reading numpy files from {}'.format(path)) + + npy_files = glob.glob(path + '/*.npy') + npy_files.sort() + print(' found {} numpy files'.format(len(npy_files))) + + size_dict = {} + counter = 0 + start_time = time.time() + for filename in npy_files: + data = np.load(filename, allow_pickle=True) + size = np.hstack(data).size + np_filename = os.path.basename(filename) + size_dict[np_filename] = size + counter += 1 + if counter % 10 == 0: + print(' processed {} files in {:.2f} seconds'.format( + counter, time.time() - start_time)) + + output_filename = os.path.join(path, 'sizes.txt') + with open(output_filename, 'w') as f: + json.dump(size_dict, f) + print('> wrote sizes to {}'.format(output_filename)) diff --git a/examples/Megatron-LM/openwebtext/merge_jsons.py b/examples/Megatron-LM/openwebtext/merge_jsons.py new file mode 100644 index 0000000..6cec66d --- /dev/null +++ b/examples/Megatron-LM/openwebtext/merge_jsons.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import glob +import sys +import json +import argparse + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--json_path", type=str, default=".", + help="path where all the json files are located") + + parser.add_argument("--output_file", type=str, default="merged_output.json", + help="filename where the merged json should go") + + args = parser.parse_args() + + json_path = args.json_path + out_file = args.output_file + + json_files = glob.glob(json_path + '/*.json') + + counter = 0 + + with open(out_file, 'w') as outfile: + for fname in json_files: + counter += 1 + + if counter % 1024 == 0: + print("Merging at ", counter, flush=True) + + with open(fname, 'r') as infile: + for row in infile: + each_row = json.loads(row) + outfile.write(row) + + + print("Merged file", out_file, flush=True) + + diff --git a/examples/Megatron-LM/openwebtext/remove_group_duplicates.py b/examples/Megatron-LM/openwebtext/remove_group_duplicates.py new file mode 100644 index 0000000..8784809 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/remove_group_duplicates.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import json +import time +import sys + + +if __name__ == '__main__': + + url_filename = sys.argv[1] + data_filename = sys.argv[2] + output_filename = sys.argv[3] + + urls = set() + with open(url_filename, 'r') as f: + for line in f: + myjson = json.loads(line) + for key in myjson: + this_urls = myjson[key] + for i in range(1, len(this_urls)): + urls.add(this_urls[i]) + print('will be removing {} urls'.format(len(urls)), flush=True) + + written_docs = 0 + removed_docs = 0 + removed_chars = 0 + start_time = time.time() + with open(output_filename, 'wb') as fout: + with open(data_filename, 'r') as fin: + for line in fin: + try: + myjson = json.loads(line) + url = myjson['url'] + if url in urls: + print('removing', myjson) + removed_docs += 1 + removed_chars += len(myjson['text']) + continue + myjson = json.dumps(myjson, ensure_ascii=False) + fout.write(myjson.encode('utf-8')) + fout.write('\n'.encode('utf-8')) + written_docs += 1 + if written_docs % 10000 == 0: + print(' [PROCESSED] time (s): {:.2f} | written: {} ' + '| removed: {} (char: {})'.format( + time.time() - start_time, + written_docs, removed_docs, removed_chars)) + except Exception as e: + print('[SKIPPING]', line, e) + + print(' [PROCESSED] time (s): {:.2f} | written: {} ' + '| removed: {} (char: {})'.format( + time.time() - start_time, + written_docs, removed_docs, removed_chars)) + print('done :-)') diff --git a/examples/Megatron-LM/openwebtext/run_make_gpt2_dataset.sh b/examples/Megatron-LM/openwebtext/run_make_gpt2_dataset.sh new file mode 100755 index 0000000..7afd480 --- /dev/null +++ b/examples/Megatron-LM/openwebtext/run_make_gpt2_dataset.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +echo "processing gpt2 data ..." +DIR="/raid/mpatwary/redownload_v0/0-21" + +for thread in {0..3}; do + echo " launching thread "$thread && python make_gpt2_dataset.py $DIR $thread > $DIR/logs/shard_$thread.log 2>&1 & +done diff --git a/examples/Megatron-LM/openwebtext/tokenizer.py b/examples/Megatron-LM/openwebtext/tokenizer.py new file mode 100644 index 0000000..d38306f --- /dev/null +++ b/examples/Megatron-LM/openwebtext/tokenizer.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append('..') + +from data_utils.tokenization_gpt2 import GPT2Tokenizer + + +class Tokenizer: + + def __init__(self, cache_dir=None): + self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', + cache_dir=cache_dir) + self.tokenizer.max_len = int(1e12) + self.eod_token = self.tokenizer.encoder['<|endoftext|>'] + assert self.eod_token < 65535, 'vocab size will not fit in uint16' + print('> GPT2 tokenizer with {} vocab size and eod token {} ...'.format( + len(self.tokenizer.encoder), self.eod_token)) + + def tokenize_document(self, document): + tokens = self.tokenizer.encode(document) + tokens.append(self.eod_token) + return tokens diff --git a/examples/Megatron-LM/pretrain_bert.py b/examples/Megatron-LM/pretrain_bert.py new file mode 100755 index 0000000..f942f30 --- /dev/null +++ b/examples/Megatron-LM/pretrain_bert.py @@ -0,0 +1,586 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain BERT""" + +# Flag to use Pytorch ddp which uses overlapping communication and computation. +USE_TORCH_DDP = False + +from datetime import datetime +import os +import random +import numpy as np +import torch +import torch.nn.functional as F +from deepspeed.accelerator.real_accelerator import get_accelerator + +from arguments import get_args +from configure_data import configure_data +from fp16 import FP16_Module +from fp16 import FP16_Optimizer +from learning_rates import AnnealingLR +from model import BertModel +from model import get_params_for_weight_decay_optimization +from model import gpt2_get_params_for_weight_decay_optimization +if USE_TORCH_DDP: + from torch.nn.parallel.distributed import DistributedDataParallel as DDP +else: + from model import DistributedDataParallel as DDP +import mpu +from deepspeed.accelerator.real_accelerator import get_accelerator +if get_accelerator().device_name() == 'cuda': + from apex.optimizers import FusedAdam as Adam +else: + from torch.optim import Adam +from utils import Timers +from utils import save_checkpoint +from utils import load_checkpoint +from utils import report_memory +from utils import print_args +from utils import print_params_min_max_norm +from utils import print_rank_0 + + +def get_model(args): + """Build the model.""" + + print_rank_0('building BERT model ...') + model = BertModel(args) + + if mpu.get_data_parallel_rank() == 0: + print(' > number of parameters on model parallel rank {}: {}'.format( + mpu.get_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()])), flush=True) + + # GPU allocation. + model.to(torch.device(get_accelerator().current_device_name())) + + # Fp16 conversion. + if args.fp16: + model = FP16_Module(model) + if args.fp32_embedding: + model.module.model.bert.embeddings.word_embeddings.float() + model.module.model.bert.embeddings.position_embeddings.float() + model.module.model.bert.embeddings.token_type_embeddings.float() + if args.fp32_tokentypes: + model.module.model.bert.embeddings.token_type_embeddings.float() + if args.fp32_layernorm: + for name, _module in model.named_modules(): + if 'LayerNorm' in name: + _module.float() + + # Wrap model for distributed training. + if USE_TORCH_DDP: + i = torch.device(get_accelerator().current_device_name()) + model = DDP(model, device_ids=[i], output_device=i, + process_group=mpu.get_data_parallel_group()) + else: + model = DDP(model) + + return model + + +def get_optimizer(model, args): + """Set up the optimizer.""" + + # Build parameter groups (weight decay and non-decay). + while isinstance(model, (DDP, FP16_Module)): + model = model.module + layers = model.model.bert.encoder.layer + pooler = model.model.bert.pooler + lmheads = model.model.cls.predictions + nspheads = model.model.cls.seq_relationship + embeddings = model.model.bert.embeddings + param_groups = [] + param_groups += list(get_params_for_weight_decay_optimization(layers)) + param_groups += list(get_params_for_weight_decay_optimization(pooler)) + param_groups += list(get_params_for_weight_decay_optimization(nspheads)) + param_groups += list(get_params_for_weight_decay_optimization(embeddings)) + param_groups += list(get_params_for_weight_decay_optimization( + lmheads.transform)) + param_groups[1]['params'].append(lmheads.bias) + + # Add model parallel attribute if it is not set. + for param_group in param_groups: + for param in param_group['params']: + if not hasattr(param, 'model_parallel'): + param.model_parallel = False + + # Use Adam. + optimizer = Adam(param_groups, + lr=args.lr, weight_decay=args.weight_decay) + + # Wrap into fp16 optimizer. + if args.fp16: + optimizer = FP16_Optimizer(optimizer, + static_loss_scale=args.loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={ + 'scale_window': args.loss_scale_window, + 'min_scale':args.min_scale, + 'delayed_shift': args.hysteresis}) + + return optimizer + + +def get_learning_rate_scheduler(optimizer, args): + """Build the learning rate scheduler.""" + + # Add linear learning rate scheduler. + if args.lr_decay_iters is not None: + num_iters = args.lr_decay_iters + else: + num_iters = args.train_iters + init_step = -1 + warmup_iter = args.warmup * num_iters + lr_scheduler = AnnealingLR(optimizer, + start_lr=args.lr, + warmup_iter=warmup_iter, + num_iters=num_iters, + decay_style=args.lr_decay_style, + last_iter=init_step) + + return lr_scheduler + + +def setup_model_and_optimizer(args): + """Setup model and optimizer.""" + + model = get_model(args) + optimizer = get_optimizer(model, args) + lr_scheduler = get_learning_rate_scheduler(optimizer, args) + + if args.load is not None: + args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args) + else: + args.iteration = 0 + + return model, optimizer, lr_scheduler + + +def get_batch(data_iterator, timers): + ''' get_batch subdivides the source data into chunks of + length args.seq_length. If source is equal to the example + output of the data loading example, with a seq_length limit + of 2, we'd get the following two Variables for i = 0: + ┌ a g m s ┐ ┌ b h n t ┐ + └ b h n t ┘ └ c i o u ┘ + Note that despite the name of the function, the subdivison of data is not + done along the batch dimension (i.e. dimension 1), since that was handled + by the data loader. The chunks are along dimension 0, corresponding + to the seq_len dimension in the LSTM. A Variable representing an appropriate + shard reset mask of the same dimensions is also returned. + ''' + # Items and their type. + keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b['text'].long() + types = data_b['types'].long() + next_sentence = data_b['is_random'].long() + loss_mask = data_b['mask'].float() + lm_labels = data_b['mask_labels'].long() + padding_mask = data_b['pad_mask'].byte() + + return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + tokens, types, next_sentence, loss_mask, lm_labels, \ + padding_mask = get_batch(data_iterator, timers) + timers('batch generator').stop() + # Forward model. + output, nsp = model(tokens, types, 1-padding_mask, + checkpoint_activations=args.checkpoint_activations) + + nsp_loss = F.cross_entropy(nsp.view(-1, 2).contiguous().float(), + next_sentence.view(-1).contiguous(), + ignore_index=-1) + + losses = mpu.vocab_parallel_cross_entropy( + output.contiguous().float(), lm_labels.contiguous()) + loss_mask = loss_mask.contiguous() + lm_loss = torch.sum( + losses.view(-1) * loss_mask.view(-1).float()) / loss_mask.sum() + + return lm_loss, nsp_loss + + +def backward_step(optimizer, model, lm_loss, nsp_loss, args): + """Backward step.""" + + # Total loss. + loss = lm_loss + nsp_loss + + # Backward pass. + optimizer.zero_grad() + if args.fp16: + optimizer.backward(loss, update_master_grads=False) + else: + loss.backward() + + # Reduce across processes. + lm_loss_reduced = lm_loss + nsp_loss_reduced = nsp_loss + + reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1))) + torch.distributed.all_reduce(reduced_losses.data) + reduced_losses.data = reduced_losses.data / args.world_size + if not USE_TORCH_DDP: + model.allreduce_params(reduce_after=False, + fp32_allreduce=args.fp32_allreduce) + lm_loss_reduced = reduced_losses[0] + nsp_loss_reduced = reduced_losses[1] + + # Update master gradients. + if args.fp16: + optimizer.update_master_grads() + + # Clipping gradients helps prevent the exploding gradient. + if args.clip_grad > 0: + if not args.fp16: + mpu.clip_grad_norm(model.parameters(), args.clip_grad) + else: + optimizer.clip_master_grads(args.clip_grad) + + return lm_loss_reduced, nsp_loss_reduced + + +def train_step(data_iterator, model, optimizer, lr_scheduler, + args, timers): + """Single training step.""" + + # Forward model for one step. + timers('forward').start() + lm_loss, nsp_loss = forward_step(data_iterator, model, + args, timers) + timers('forward').stop() + + # Calculate gradients, reduce across processes, and clip. + timers('backward').start() + lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss, + nsp_loss, args) + timers('backward').stop() + + # Update parameters. + timers('optimizer').start() + optimizer.step() + timers('optimizer').stop() + + # Update learning rate. + skipped_iter = 0 + if not (args.fp16 and optimizer.overflow): + lr_scheduler.step() + else: + skipped_iter = 1 + + return lm_loss_reduced, nsp_loss_reduced, skipped_iter + + +def train(model, optimizer, lr_scheduler, + train_data_iterator, val_data_iterator, timers, args): + """Train the model.""" + + # Turn on training mode which enables dropout. + model.train() + + # Tracking loss. + total_lm_loss = 0.0 + total_nsp_loss = 0.0 + + # Iterations. + iteration = args.iteration + skipped_iters = 0 + + timers('interval time').start() + report_memory_flag = True + while iteration < args.train_iters: + + lm_loss, nsp_loss, skipped_iter = train_step(train_data_iterator, + model, + optimizer, + lr_scheduler, + args, timers) + skipped_iters += skipped_iter + iteration += 1 + + # Update losses. + total_lm_loss += lm_loss.data.detach().float() + total_nsp_loss += nsp_loss.data.detach().float() + + # Logging. + if iteration % args.log_interval == 0: + learning_rate = optimizer.param_groups[0]['lr'] + avg_nsp_loss = total_nsp_loss.item() / args.log_interval + avg_lm_loss = total_lm_loss.item() / args.log_interval + elapsed_time = timers('interval time').elapsed() + log_string = ' iteration {:8d}/{:8d} |'.format(iteration, + args.train_iters) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time * 1000.0 / args.log_interval) + log_string += ' learning rate {:.3E} |'.format(learning_rate) + log_string += ' lm loss {:.6E} |'.format(avg_lm_loss) + log_string += ' nsp loss {:.6E} |'.format(avg_nsp_loss) + if args.fp16: + log_string += ' loss scale {:.1f} |'.format( + optimizer.loss_scale) + print_rank_0(log_string) + total_nsp_loss = 0.0 + total_lm_loss = 0.0 + if report_memory_flag: + report_memory('after {} iterations'.format(iteration)) + report_memory_flag = False + timers.log(['forward', 'backward', 'optimizer', 'batch generator', + 'data loader'], + normalizer=args.log_interval) + # Checkpointing + if args.save and args.save_interval and iteration % args.save_interval == 0: + save_checkpoint(iteration, model, optimizer, lr_scheduler, args) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results( + prefix, val_data_iterator, model, args, timers, False) + + if args.exit_interval and iteration % args.exit_interval == 0: + torch.distributed.barrier() + time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + rank = torch.distributed.get_rank() + print('rank: {} | time: {} | exiting the program at iteration {}'. + format(rank, time_str, iteration), flush=True) + exit() + + return iteration, skipped_iters + + +def evaluate(data_iterator, model, args, timers, verbose = False): + """Evaluation.""" + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_lm_loss = 0 + total_nsp_loss = 0 + + with torch.no_grad(): + iteration = 0 + while iteration < args.eval_iters: + iteration += 1 + if verbose and iteration % args.log_interval == 0: + print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters)) + # Forward evaluation. + lm_loss, nsp_loss = forward_step(data_iterator, model, + args, timers) + # Reduce across processes. + if isinstance(model, DDP): + reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1))) + torch.distributed.all_reduce(reduced_losses.data) + reduced_losses.data = reduced_losses.data/args.world_size + lm_loss = reduced_losses[0] + nsp_loss = reduced_losses[1] + + total_lm_loss += lm_loss.data.detach().float().item() + total_nsp_loss += nsp_loss.data.detach().float().item() + + # Move model back to the train mode. + model.train() + + total_lm_loss /= args.eval_iters + total_nsp_loss /= args.eval_iters + return total_lm_loss, total_nsp_loss + + +def evaluate_and_print_results(prefix, data_iterator, model, + args, timers, verbose=False): + """Helper function to evaluate and dump results on screen.""" + lm_loss, nsp_loss = evaluate(data_iterator, model, + args, timers, verbose) + val_loss = lm_loss + nsp_loss + print_rank_0('-' * 100) + string = ' validation loss at {} | '.format(prefix) + string += 'LM loss: {:.6E} | '.format(lm_loss) + string += 'NSP loss: {:.6E} | '.format(nsp_loss) + string += 'total loss: {:.6E}'.format(val_loss) + length = len(string) + 1 + print_rank_0('-' * length) + print_rank_0(string) + print_rank_0('-' * length) + + return val_loss + + +def initialize_distributed(args): + """Initialize torch.distributed.""" + + # Manually set the device ids. + device = args.rank % get_accelerator().device_count() + if args.local_rank is not None: + device = args.local_rank + get_accelerator().set_device(device) + # Call the init process + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, rank=args.rank, + init_method=init_method) + + # Set the model-parallel / data-parallel communicators. + mpu.initialize_model_parallel(args.model_parallel_size) + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + + if seed is not None and seed > 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def get_train_val_test_data(args): + """Load the data on rank zero and boradcast number of tokens to all GPUS.""" + + (train_data, val_data, test_data) = (None, None, None) + + # Data loader only on rank 0 of each model parallel group. + if mpu.get_model_parallel_rank() == 0: + data_config = configure_data() + data_config.set_defaults(data_set_type='BERT', transpose=False) + (train_data, val_data, test_data), tokenizer = data_config.apply(args) + before = tokenizer.num_tokens + after = before + multiple = args.make_vocab_size_divisible_by * \ + mpu.get_model_parallel_world_size() + while (after % multiple) != 0: + after += 1 + print_rank_0('> padded vocab (size: {}) with {} dummy ' + 'tokens (new size: {})'.format( + before, after - before, after)) + # Need to broadcast num_tokens and num_type_tokens. + token_counts = get_accelerator().LongTensor([after, + tokenizer.num_type_tokens, + int(args.do_train), int(args.do_valid), int(args.do_test)]) + else: + token_counts = get_accelerator().LongTensor([0, 0, 0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast(token_counts, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + num_tokens = token_counts[0].item() + num_type_tokens = token_counts[1].item() + args.do_train = token_counts[2].item() + args.do_valid = token_counts[3].item() + args.do_test = token_counts[4].item() + + return train_data, val_data, test_data, num_tokens, num_type_tokens + + +def main(): + """Main training program.""" + + # Disable CuDNN. + torch.backends.cudnn.enabled = False + + # Timer. + timers = Timers() + + # Arguments. + args = get_args() + + # Pytorch distributed. + initialize_distributed(args) + if torch.distributed.get_rank() == 0: + print('Pretrain BERT model') + print_args(args) + + # Random seeds for reproducability. + set_random_seed(args.seed) + + # Data stuff. + train_data, val_data, test_data, args.tokenizer_num_tokens, \ + args.tokenizer_num_type_tokens = get_train_val_test_data(args) + + # Model, optimizer, and learning rate. + model, optimizer, lr_scheduler = setup_model_and_optimizer(args) + + if args.resume_dataloader: + if train_data is not None: + train_data.batch_sampler.start_iter = args.iteration % \ + len(train_data) + if val_data is not None: + start_iter_val = (args.train_iters // args.save_interval) * \ + args.eval_interval + val_data.batch_sampler.start_iter = start_iter_val % \ + len(val_data) + + if train_data is not None: + train_data_iterator = iter(train_data) + else: + train_data_iterator = None + if val_data is not None: + val_data_iterator = iter(val_data) + else: + val_data_iterator = None + + iteration = 0 + if args.train_iters > 0: + if args.do_train: + iteration, skipped = train(model, optimizer, + lr_scheduler, + train_data_iterator, + val_data_iterator, + timers, args) + if args.do_valid: + prefix = 'the end of training for val data' + val_loss = evaluate_and_print_results(prefix, val_data_iterator, + model, args, timers, False) + + if args.save and iteration != 0: + save_checkpoint(iteration, model, optimizer, lr_scheduler, args) + + if test_data is not None: + test_data_iterator = iter(test_data) + else: + test_data_iterator = None + + if args.do_test: + # Run on test data. + prefix = 'the end of training for test data' + evaluate_and_print_results(prefix, test_data_iterator, + model, args, timers, True) + + +if __name__ == "__main__": + main() diff --git a/examples/Megatron-LM/pretrain_gpt2.py b/examples/Megatron-LM/pretrain_gpt2.py new file mode 100755 index 0000000..2d7e52b --- /dev/null +++ b/examples/Megatron-LM/pretrain_gpt2.py @@ -0,0 +1,771 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pretrain GPT2""" + +# Flag to use Pytorch ddp which uses overlapping communication and computation. +USE_TORCH_DDP = False + +from datetime import datetime +import os +import random +import math +import numpy as np +import torch + +import deepspeed +from deepspeed.accelerator.real_accelerator import get_accelerator + +from arguments import get_args +from configure_data import configure_data +from fp16 import FP16_Module +from fp16 import FP16_Optimizer +from bf16 import BF16_Module +from learning_rates import AnnealingLR +from model import GPT2Model +from model import gpt2_get_params_for_weight_decay_optimization + +if USE_TORCH_DDP: + from torch.nn.parallel.distributed import DistributedDataParallel as DDP +else: + from model import DistributedDataParallel as DDP +import mpu +from deepspeed.accelerator.real_accelerator import get_accelerator +if get_accelerator().device_name() == 'cuda': + from apex.optimizers import FusedAdam as Adam +else: + from torch.optim import Adam +from utils import Timers +from utils import save_checkpoint +from utils import load_checkpoint +from utils import report_memory +from utils import print_args +from utils import print_params_min_max_norm +from utils import print_rank_0 +import torch.distributed as dist + +from gpt2_data_loader import make_gpt2_dataloaders +import subprocess +import sys +import os + +def get_model(args): + """Build the model.""" + + print_rank_0('building GPT2 model ...') + model = GPT2Model(num_layers=args.num_layers, + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + embedding_dropout_prob=args.hidden_dropout, + attention_dropout_prob=args.attention_dropout, + output_dropout_prob=args.hidden_dropout, + max_sequence_length=args.max_position_embeddings, + checkpoint_activations=args.checkpoint_activations, + checkpoint_num_layers=args.checkpoint_num_layers, + parallel_output=True) + + if mpu.get_data_parallel_rank() == 0: + print(' > number of parameters on model parallel rank {}: {}'.format( + mpu.get_model_parallel_rank(), + sum([p.nelement() for p in model.parameters()])), flush=True) + + #To prevent OOM for model sizes that cannot fit in GPU memory in full precision + if args.deepspeed: + if args.fp16: + model.half() + elif args.bf16: + model.bfloat16() + + # GPU allocation. + model.to(torch.device(get_accelerator().current_device_name())) + + # Fp16 conversion. + if args.fp16: + model = FP16_Module(model) + elif args.bf16: + model = BF16_Module(model) + + # Wrap model for distributed training. + if USE_TORCH_DDP: + i = torch.device(get_accelerator().current_device_name()) + model = DDP(model, device_ids=[i], output_device=i, + process_group=mpu.get_data_parallel_group()) + else: + model = DDP(model) + + return model + + +def get_optimizer(model, args): + """Set up the optimizer.""" + + # Build parameter groups (weight decay and non-decay). + while isinstance(model, (DDP, FP16_Module)): + model = model.module + param_groups = gpt2_get_params_for_weight_decay_optimization(model) + + # Add model parallel attribute if it is not set. + for param_group in param_groups: + for param in param_group['params']: + if not hasattr(param, 'model_parallel'): + param.model_parallel = False + + # Use FusedAdam. + optimizer = Adam(param_groups, + lr=args.lr, weight_decay=args.weight_decay) + + print(f'Optimizer = {optimizer.__class__.__name__}') + + if args.deepspeed: + return optimizer, param_groups + + # Wrap into fp16 optimizer. + if args.fp16: + + optimizer = FP16_Optimizer(optimizer, + static_loss_scale=args.loss_scale, + dynamic_loss_scale=args.dynamic_loss_scale, + dynamic_loss_args={ + 'scale_window': args.loss_scale_window, + 'min_scale': args.min_scale, + 'delayed_shift': args.hysteresis}) + + elif args.bf16: + + optimizer = FP16_Optimizer(optimizer, static_loss_scale=1.0) + + return optimizer, param_groups + + +def get_learning_rate_scheduler(optimizer, args): + """Build the learning rate scheduler.""" + + # Add linear learning rate scheduler. + if args.lr_decay_iters is not None: + num_iters = args.lr_decay_iters + else: + num_iters = args.train_iters + num_iters = max(1, num_iters) + init_step = -1 + warmup_iter = args.warmup * num_iters + lr_scheduler = AnnealingLR(optimizer, + start_lr=args.lr, + warmup_iter=warmup_iter, + num_iters=num_iters, + decay_style=args.lr_decay_style, + last_iter=init_step) + + return lr_scheduler + + +def setup_model_and_optimizer(args): + """Setup model and optimizer.""" + + model = get_model(args) + optimizer, param_groups = get_optimizer(model, args) + lr_scheduler = get_learning_rate_scheduler(optimizer, args) + + if args.deepspeed: + print_rank_0("DeepSpeed is enabled.") + + model, optimizer, _, lr_scheduler = deepspeed.initialize( + model=model, + optimizer=optimizer, + model_parameters=param_groups, + args=args, + lr_scheduler=lr_scheduler, + mpu=mpu, + dist_init_required=False + ) + + if args.load is not None: + args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args) + else: + args.iteration = 0 + + return model, optimizer, lr_scheduler + + +def get_masks_and_position_ids(data, + eod_token, + reset_position_ids, + reset_attention_mask): + # Extract batch size and sequence length. + batch_size, seq_length = data.size() + + # Attention mask (lower triangular). + if reset_attention_mask: + att_mask_batch = batch_size + else: + att_mask_batch = 1 + attention_mask = torch.tril(torch.ones( + (att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length) + + # Loss mask. + loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) + loss_mask[data == eod_token] = 0.0 + + # Position ids. + position_ids = torch.arange(seq_length, dtype=torch.long, + device=data.device) + position_ids = position_ids.unsqueeze(0).expand_as(data) + # We need to clone as the ids will be modifed based on batch index. + if reset_position_ids: + position_ids = position_ids.clone() + + if reset_position_ids or reset_attention_mask: + # Loop through the batches: + for b in range(batch_size): + + # Find indecies where EOD token is. + eod_index = position_ids[b, data[b] == eod_token] + # Detach indecies from positions if going to modify positions. + if reset_position_ids: + eod_index = eod_index.clone() + + # Loop through EOD indecies: + prev_index = 0 + for j in range(eod_index.size()[0]): + i = eod_index[j] + # Mask attention loss. + if reset_attention_mask: + attention_mask[b, 0, (i+1):, :(i+1)] = 0 + # Reset positions. + if reset_position_ids: + position_ids[b, (i+1):] -= (i + 1 - prev_index) + prev_index = i + 1 + + return attention_mask, loss_mask, position_ids + + +def get_batch(data_iterator, args, timers): + ''' get_batch subdivides the source data into chunks of + length args.seq_length. If source is equal to the example + output of the data loading example, with a seq_length limit + of 2, we'd get the following two Variables for i = 0: + ┌ a g m s ┐ ┌ b h n t ┐ + └ b h n t ┘ └ c i o u ┘ + Note that despite the name of the function, the subdivison of data is not + done along the batch dimension (i.e. dimension 1), since that was handled + by the data loader. The chunks are along dimension 0, corresponding + to the seq_len dimension in the LSTM. A Variable representing an appropriate + shard reset mask of the same dimensions is also returned. + ''' + # Items and their type. + keys = ['text'] + datatype = torch.int64 + + # Broadcast data. + timers('data loader').start() + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + timers('data loader').stop() + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_masks_and_position_ids( + tokens, + args.eod_token, + args.reset_position_ids, + args.reset_attention_mask) + # Convert + if args.fp16: + attention_mask = attention_mask.half() + elif args.bf16: + attention_mask = attention_mask.bfloat16() + + return tokens, labels, loss_mask, attention_mask, position_ids + + +def forward_step(data_iterator, model, args, timers): + """Forward step.""" + + # Get the batch. + timers('batch generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator, args, timers) + timers('batch generator').stop() + + # Forward model. + output = model(tokens, position_ids, attention_mask) + losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), + labels) + loss_mask = loss_mask.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + return loss + + +def backward_step(optimizer, model, lm_loss, args, timers): + """Backward step.""" + + # Total loss. + loss = lm_loss + + # Backward pass. + if args.deepspeed: + model.backward(loss) + else: + optimizer.zero_grad() + if args.fp16: + optimizer.backward(loss, update_master_grads=False) + else: + loss.backward() + + # Reduce across processes. + lm_loss_reduced = lm_loss + + reduced_losses = lm_loss.view(1) + + if args.deepspeed: + # DeepSpeed backward propagation already addressed all reduce communication. + # Reset the timer to avoid breaking timer logs below. + timers('allreduce').reset() + else: + torch.distributed.all_reduce(reduced_losses.data) + reduced_losses.data = reduced_losses.data / args.world_size + if not USE_TORCH_DDP: + timers('allreduce').start() + model.allreduce_params(reduce_after=False, + fp32_allreduce=args.fp32_allreduce) + timers('allreduce').stop() + + lm_loss_reduced = reduced_losses + + # Update master gradients. + if not args.deepspeed: + if args.fp16: + optimizer.update_master_grads() + + # Clipping gradients helps prevent the exploding gradient. + if args.clip_grad > 0: + if not args.fp16: + mpu.clip_grad_norm(model.parameters(), args.clip_grad) + else: + optimizer.clip_master_grads(args.clip_grad) + + return lm_loss_reduced + +def see_memory_usage(message, force=False): + if not force: + return + dist.barrier() + if dist.get_rank() == 0: + print(message) + print("Memory Allocated ", get_accelerator().memory_allocated()/(1024*1024*1024), "GigaBytes") + print("Max Memory Allocated ", get_accelerator().max_memory_allocated()/(1024*1024*1024), "GigaBytes") + print("Cache Allocated ", get_accelerator().memory_cached()/(1024*1024*1024), "GigaBytes") + print("Max cache Allocated ", get_accelerator().max_memory_cached()/(1024*1024*1024), "GigaBytes") + print(" ") + #input("Press Any Key To Continue ..") + +def train_step(data_iterator, model, optimizer, lr_scheduler, + args, timers): + """Single training step.""" + + # Forward model for one step. + timers('forward').start() + lm_loss = forward_step(data_iterator, model, args, timers) + timers('forward').stop() + + #print_rank_0("loss is {}".format(lm_loss)) + + # Calculate gradients, reduce across processes, and clip. + timers('backward').start() + lm_loss_reduced = backward_step(optimizer, model, lm_loss, args, timers) + timers('backward').stop() + + # Update parameters. + skipped_iter = 0 + timers('optimizer').start() + if args.deepspeed: + model.step() + else: + optimizer.step() + + # Update learning rate. + if not (args.fp16 and optimizer.overflow): + lr_scheduler.step() + else: + skipped_iter = 1 + timers('optimizer').stop() + + return lm_loss_reduced, skipped_iter + + +def train(model, optimizer, lr_scheduler, + train_data_iterator, val_data_iterator, timers, args): + """Train the model.""" + + # Turn on training mode which enables dropout. + model.train() + + # Tracking loss. + total_lm_loss = 0.0 + + # Iterations. + iteration = args.iteration + skipped_iters = 0 + + timers('interval time').start() + report_memory_flag = False + total_time = 0 + total_index = 0 + prof_ite = args.train_iters - 1 + prof_enabled = False + while iteration < args.train_iters: + if iteration == prof_ite: + prof_enabled = True + else: + prof_enabled = False + begin = datetime.now() + print(f"to run iteration {iteration}", "at", begin.strftime('%Y-%m-%d %H:%M:%S')) + with torch.autograd.profiler_legacy.profile(enabled=prof_enabled, use_xpu=True) as prof: + lm_loss, skipped_iter = train_step(train_data_iterator, + model, + optimizer, + lr_scheduler, + args, timers) + get_accelerator().synchronize() # for safty to profile cpu/xpu time, should be removed at last + end = datetime.now() + diff = end - begin + print(f"iteration {iteration} is finished", "at", end.strftime('%Y-%m-%d %H:%M:%S'), "(", diff.total_seconds(), "seconds)") + + if iteration > 2 and iteration < prof_ite: + total_time += diff.total_seconds() + total_index += 1 + print("average time per iteration (ignore iter 0/1/2):", total_time/total_index, "seconds at iteration ", iteration) + + if prof_enabled: + print(prof.key_averages().table(sort_by="self_xpu_time_total", row_limit=-1)) + + # check the GPU memory after each iteration + if not args.disable_sysmon: + sysmon = '/home/adnguye1/repos/tools/pti-gpu/tools/sysmon/build/sysmon' + if os.path.isfile(sysmon): + subprocess.call(sysmon) + + skipped_iters += skipped_iter + iteration += 1 + + # Update losses. + total_lm_loss += lm_loss.data.detach().float() + + # Logging. + if iteration % args.log_interval == 0: + learning_rate = optimizer.param_groups[0]['lr'] + avg_lm_loss = total_lm_loss.item() / args.log_interval + elapsed_time = timers('interval time').elapsed() + log_string = ' iteration {:8d}/{:8d} |'.format(iteration, + args.train_iters) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time * 1000.0 / args.log_interval) + log_string += ' learning rate {:.3E} |'.format(learning_rate) + log_string += ' lm loss {:.6E} |'.format(avg_lm_loss) + if args.fp16: + log_string += ' loss scale {:.1f} |'.format( + optimizer.cur_scale if args.deepspeed else optimizer.loss_scale) + print_rank_0(log_string) + total_lm_loss = 0.0 + if report_memory_flag: + report_memory('after {} iterations'.format(iteration)) + report_memory_flag = False + if USE_TORCH_DDP: + timers.log(['forward', 'backward', 'optimizer', + 'batch generator', 'data loader'], + normalizer=args.log_interval) + else: + timers.log(['forward', 'backward', 'allreduce', 'optimizer', + 'batch generator', 'data loader'], + normalizer=args.log_interval) + # Checkpointing + if args.save and args.save_interval and iteration % args.save_interval == 0: + save_checkpoint(iteration, model, optimizer, lr_scheduler, args) + + # Evaluation + if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid: + prefix = 'iteration {}'.format(iteration) + evaluate_and_print_results( + prefix, val_data_iterator, model, args, timers, False) + + if args.exit_interval and iteration % args.exit_interval == 0: + torch.distributed.barrier() + time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + rank = torch.distributed.get_rank() + print('rank: {} | time: {} | exiting the program at iteration {}'. + format(rank, time_str, iteration), flush=True) + exit() + + return iteration, skipped_iters + + +def evaluate(data_iterator, model, args, timers, verbose=False): + """Evaluation.""" + + # Turn on evaluation mode which disables dropout. + model.eval() + + total_lm_loss = 0 + + with torch.no_grad(): + iteration = 0 + while iteration < args.eval_iters: + iteration += 1 + if verbose and iteration % args.log_interval == 0: + print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters)) + # Forward evaluation. + lm_loss = forward_step(data_iterator, model, args, timers) + + '''when contiguous memory optimizations are enabled, the buffers + allocated by the optimizations are deallocated during backward pass + in the absence of backward pass the buffers should be reset after each + forward pass''' + if args.deepspeed and args.deepspeed_activation_checkpointing: + deepspeed.checkpointing.reset() + + # Reduce across processes. + if isinstance(model, DDP): + torch.distributed.all_reduce(lm_loss.data) + lm_loss.data = lm_loss.data / args.world_size + + total_lm_loss += lm_loss.data.detach().float().item() + + # Move model back to the train mode. + model.train() + + total_lm_loss /= args.eval_iters + return total_lm_loss + + +def evaluate_and_print_results(prefix, data_iterator, model, + args, timers, verbose=False): + """Helper function to evaluate and dump results on screen.""" + lm_loss = evaluate(data_iterator, model, args, timers, verbose) + lm_ppl = math.exp(min(20, lm_loss)) + print_rank_0('-' * 100) + string = ' validation loss at {} | '.format(prefix) + string += 'LM loss: {:.6E} | '.format(lm_loss) + string += 'LM PPL: {:.6E}'.format(lm_ppl) + length = len(string) + 1 + print_rank_0('-' * length) + print_rank_0(string) + print_rank_0('-' * length) + + return lm_loss + +''' + Optional DeepSpeed Activation Checkpointing features + Gives access to partition activations, contiguous memory optimizations + and cpu checkpointing. + + Activation checkpoint requires keep track of the random states + and setting the random seed for each MP process. Megatron uses + mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed + for keeping track of the random states and setting the random seeds. + Since they are used in places outside of activation checkpointing, + we overwrite them to maintain consistency. + + This must be done before all the calls to mpu.model_parallel_cuda_manual_seed + ''' +def set_deepspeed_activation_checkpointing(args): + + deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers) + mpu.checkpoint = deepspeed.checkpointing.checkpoint + mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker + mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed + +def initialize_distributed(args): + """Initialize torch.distributed.""" + + if args.deepspeed: + deepspeed.init_distributed(dist_backend=args.distributed_backend) + else: + # Manually set the device ids. + device = args.rank % get_accelerator().device_count() + # Call the init process + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '6000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, rank=args.rank, + init_method=init_method) + + if args.local_rank is not None: + device = args.local_rank + get_accelerator().set_device(device) + + # Set the model-parallel / data-parallel communicators. + mpu.initialize_model_parallel(args.model_parallel_size) + + # Optional DeepSpeed Activation Checkpointing Features + # + if args.deepspeed and args.deepspeed_activation_checkpointing: + set_deepspeed_activation_checkpointing(args) + + +def set_random_seed(seed): + """Set random seed for reproducability.""" + + if seed is not None and seed > 0: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + mpu.model_parallel_cuda_manual_seed(seed) + + +def get_train_val_test_data(args): + """Load the data on rank zero and boradcast number of tokens to all GPUS.""" + + (train_data, val_data, test_data) = (None, None, None) + + # Data loader only on rank 0 of each model parallel group. + if mpu.get_model_parallel_rank() == 0: + if args.use_npy_data_loader: + (train_data, val_data, test_data), num_tokens, \ + eod_token = make_gpt2_dataloaders(args) + else: + data_config = configure_data() + data_config.set_defaults(data_set_type='GPT2', transpose=False) + (train_data, val_data, test_data), tokenizer = data_config.apply( + args) + num_tokens = tokenizer.num_tokens + eod_token = tokenizer.get_command('eos').Id + assert eod_token == tokenizer.get_command('pad').Id + before = num_tokens + after = before + multiple = args.make_vocab_size_divisible_by * \ + mpu.get_model_parallel_world_size() + while (after % multiple) != 0: + after += 1 + print_rank_0('> padded vocab (size: {}) with {} dummy ' + 'tokens (new size: {})'.format( + before, after - before, after)) + print_rank_0('> found end-of-document token: {}'.format(eod_token)) + token_counts = get_accelerator().LongTensor([after, eod_token, int(args.do_train), int(args.do_valid), int(args.do_test)]) + else: + token_counts = get_accelerator().LongTensor([0, 0, 0, 0, 0]) + + # Broadcast num tokens. + if mpu.get_model_parallel_group().size() > 1: + torch.distributed.broadcast(token_counts, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + num_tokens = token_counts[0].item() + eod_token = token_counts[1].item() + args.do_train = token_counts[2].item() + args.do_valid = token_counts[3].item() + args.do_test = token_counts[4].item() + + return train_data, val_data, test_data, num_tokens, eod_token + + +def main(): + """Main training program.""" + + # Disable CuDNN. + torch.backends.cudnn.enabled = False + + # Timer. + timers = Timers() + + # Arguments. + args = get_args() + + # Pytorch distributed. + initialize_distributed(args) + if torch.distributed.get_rank() == 0: + print('Pretrain GPT2 model') + print_args(args) + + # Random seeds for reproducability. + set_random_seed(args.seed) + + # Data stuff. + train_data, val_data, test_data, args.vocab_size, \ + args.eod_token = get_train_val_test_data(args) + + # Model, optimizer, and learning rate. + model, optimizer, lr_scheduler = setup_model_and_optimizer(args) + + # Resume data loader if necessary. + if args.resume_dataloader: + if train_data is not None: + train_data.batch_sampler.start_iter = args.iteration % \ + len(train_data) + if val_data is not None: + start_iter_val = (args.train_iters // args.save_interval) * \ + args.eval_interval + val_data.batch_sampler.start_iter = start_iter_val % \ + len(val_data) + if train_data is not None: + train_data_iterator = iter(train_data) + else: + train_data_iterator = None + if val_data is not None: + val_data_iterator = iter(val_data) + else: + val_data_iterator = None + + #TODO: figure out how to properly set this especially when resuming training + iteration = 0 + if args.train_iters > 0: + if args.do_train: + iteration, skipped = train(model, optimizer, + lr_scheduler, + train_data_iterator, + val_data_iterator, + timers, args) + + if args.do_valid: + prefix = 'the end of training for val data' + val_loss = evaluate_and_print_results(prefix, val_data_iterator, + model, args, timers, False) + + if args.save and iteration != 0: + save_checkpoint(iteration, model, optimizer, lr_scheduler, args) + + if test_data is not None: + test_data_iterator = iter(test_data) + else: + test_data_iterator = None + + if args.do_test: + # Run on test data. + prefix = 'the end of training for test data' + evaluate_and_print_results(prefix, test_data_iterator, + model, args, timers, True) + + +if __name__ == "__main__": + begin = datetime.now() + print("begin process", os.getpid(), "at", begin.strftime('%Y-%m-%d %H:%M:%S')) + version_file = sys.path[0] + '/versions.log' + if os.path.isfile(version_file): + subprocess.call(['cat', version_file]) + main() + end = datetime.now() + diff = end - begin + print("end process", os.getpid(), "at", end.strftime('%Y-%m-%d %H:%M:%S'), "(", diff.total_seconds()/60.0, "minutes)") diff --git a/examples/Megatron-LM/requirements.txt b/examples/Megatron-LM/requirements.txt new file mode 100644 index 0000000..f676d19 --- /dev/null +++ b/examples/Megatron-LM/requirements.txt @@ -0,0 +1,8 @@ +nltk>=3.6.6 +numpy>=1.22.2 +urllib3>=1.26.5 +pandas>=0.24.0 +sentencepiece>=0.1.8 +boto3==1.24.74 +regex==2022.9.13 +requests==2.28.1 \ No newline at end of file diff --git a/examples/Megatron-LM/scripts/ds_checkpoint_check.sh b/examples/Megatron-LM/scripts/ds_checkpoint_check.sh new file mode 100644 index 0000000..3d64f6d --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_checkpoint_check.sh @@ -0,0 +1,51 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=4 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +export DLWS_NUM_WORKER=1 +export DLWS_NUM_GPU_PER_WORKER=4 +MP_SIZE=2 + +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 2 \ + --hidden-size 256 \ + --num-attention-heads 16 \ + --batch-size 8 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 1100 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --save ds_checkpoints + --load ds_checkpoints + --save-interval 100 + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --fp16 \ + --deepspeed \ + --loss-scale 0 \ + --deepspeed_config ds_config_func_bs8.json +" + +run_cmd="deepspeed --num_nodes ${DLWS_NUM_WORKER} --num_gpus ${DLWS_NUM_GPU_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/ds_zero-offload_10B_config.json b/examples/Megatron-LM/scripts/ds_zero-offload_10B_config.json new file mode 100755 index 0000000..50d3ebb --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero-offload_10B_config.json @@ -0,0 +1,33 @@ +{ + "train_micro_batch_size_per_gpu": 10, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "reduce_bucket_size": 50000000 + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": true, + "cpu_checkpointing": true + }, + "wall_clock_breakdown": true +} + diff --git a/examples/Megatron-LM/scripts/ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh b/examples/Megatron-LM/scripts/ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh new file mode 100755 index 0000000..c54a638 --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero-offload_10B_pretrain_gpt2_model_parallel.sh @@ -0,0 +1,49 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=1 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero-offload_10B_config.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 50 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --batch-size 10 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 100 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --fp16 \ + --log-interval 1 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/ds_zero-offload_config.json b/examples/Megatron-LM/scripts/ds_zero-offload_config.json new file mode 100755 index 0000000..bd7e06b --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero-offload_config.json @@ -0,0 +1,31 @@ +{ + "train_micro_batch_size_per_gpu": 12, + "gradient_accumulation_steps": 5, + "steps_per_print": 100, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "reduce_bucket_size": 50000000 + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": true, + "cpu_checkpointing": true + }, + "wall_clock_breakdown": false +} diff --git a/examples/Megatron-LM/scripts/ds_zero-offload_config_bf16.json b/examples/Megatron-LM/scripts/ds_zero-offload_config_bf16.json new file mode 100644 index 0000000..00ba94c --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero-offload_config_bf16.json @@ -0,0 +1,44 @@ +{ + "train_micro_batch_size_per_gpu": 12, + "gradient_accumulation_steps": 5, + "steps_per_print": 100, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "reduce_bucket_size": 50000000 + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bfloat16": { + "enabled": true, + "loss_scale": 1.0 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": true, + "cpu_checkpointing": true + }, + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 5, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + } + } + \ No newline at end of file diff --git a/examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel.sh b/examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel.sh new file mode 100755 index 0000000..5203570 --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel.sh @@ -0,0 +1,48 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=4 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=16 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero-offload_config.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 4\ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 12 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 400 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --fp16 \ + --log-interval 5 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel_bf16.sh b/examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel_bf16.sh new file mode 100644 index 0000000..6c3cf73 --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero-offload_pretrain_gpt2_model_parallel_bf16.sh @@ -0,0 +1,48 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=1 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=12 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero-offload_config_bf16.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 24\ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 1 \ + --seq-length 512 \ + --max-position-embeddings 1024 \ + --train-iters 400 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend ccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --bf16 \ + --log-interval 5 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/ds_zero2_config.json b/examples/Megatron-LM/scripts/ds_zero2_config.json new file mode 100755 index 0000000..4c47642 --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero2_config.json @@ -0,0 +1,29 @@ +{ + "train_batch_size": 16, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "zero_optimization": { + "stage": 2 + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false +} diff --git a/examples/Megatron-LM/scripts/ds_zero2_config_bf16.json b/examples/Megatron-LM/scripts/ds_zero2_config_bf16.json new file mode 100755 index 0000000..f30aab3 --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero2_config_bf16.json @@ -0,0 +1,43 @@ +{ + "train_batch_size": 96, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "zero_optimization": { + "stage": 2, + "reduce_scatter": false + }, + "zero_allow_untested_optimizer": true, + "communication_data_type": "bfp16", + "gradient_clipping": 1.0, + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bfloat16": { + "enabled": true, + "loss_scale": 1.0 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": -1, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + } +} diff --git a/examples/Megatron-LM/scripts/ds_zero2_pretrain_gpt2_model_parallel.sh b/examples/Megatron-LM/scripts/ds_zero2_pretrain_gpt2_model_parallel.sh new file mode 100755 index 0000000..110a3ea --- /dev/null +++ b/examples/Megatron-LM/scripts/ds_zero2_pretrain_gpt2_model_parallel.sh @@ -0,0 +1,48 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=4 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=16 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero2_config.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 8 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 100000 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --fp16 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/generate_text.sh b/examples/Megatron-LM/scripts/generate_text.sh new file mode 100755 index 0000000..df9dc23 --- /dev/null +++ b/examples/Megatron-LM/scripts/generate_text.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +CHECKPOINT_PATH=/path/to/checkpoint +MPSIZE=1 +NLAYERS=24 +NHIDDEN=1024 +NATT=16 +MAXSEQLEN=1024 + +#SAMPLING ARGS +TEMP=0.9 +#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p +TOPK=0 +TOPP=0 + +python generate_samples.py \ + --model-parallel-size $MPSIZE \ + --num-layers $NLAYERS \ + --hidden-size $NHIDDEN \ + --load $CHECKPOINT_PATH \ + --num-attention-heads $NATT \ + --max-position-embeddings 1024 \ + --tokenizer-type GPT2BPETokenizer \ + --fp16 \ + --cache-dir cache \ + --out-seq-length $MAXSEQLEN \ + --temperature $TEMP \ + --top_k $TOPK \ + --top_p $TOPP diff --git a/examples/Megatron-LM/scripts/gpt-3.6b-fp16.sh b/examples/Megatron-LM/scripts/gpt-3.6b-fp16.sh new file mode 100755 index 0000000..cb5580c --- /dev/null +++ b/examples/Megatron-LM/scripts/gpt-3.6b-fp16.sh @@ -0,0 +1,49 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=1 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=12 + + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero2_config.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 30 \ + --hidden-size 3072 \ + --num-attention-heads 32 \ + --batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 10 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend ccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --fp16 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py ${gpt_options} $@" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/gpt-3.6b-offload.sh b/examples/Megatron-LM/scripts/gpt-3.6b-offload.sh new file mode 100644 index 0000000..b081e24 --- /dev/null +++ b/examples/Megatron-LM/scripts/gpt-3.6b-offload.sh @@ -0,0 +1,49 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=1 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=12 + + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero-offload_config_bf16.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 30 \ + --hidden-size 3072 \ + --num-attention-heads 32 \ + --batch-size 1 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 6 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend ccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --bf16 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/gpt-3.6b.sh b/examples/Megatron-LM/scripts/gpt-3.6b.sh new file mode 100755 index 0000000..9b5a469 --- /dev/null +++ b/examples/Megatron-LM/scripts/gpt-3.6b.sh @@ -0,0 +1,61 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=1 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=12 + + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_zero2_config_bf16.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 30 \ + --hidden-size 3072 \ + --num-attention-heads 32 \ + --batch-size 8 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --train-iters 10 \ + --resume-dataloader \ + --train-data c4/en \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend ccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --bf16 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + +ds_args="" +gpt2_args="" + +for i in "$@"; do + if [[ $i =~ "--oneprof_args" ]]; then + ds_args="$ds_args $i" + elif [[ $i =~ "--onetrace_args" ]]; then + ds_args="$ds_args $i" + else + gpt2_args="$gpt2_args $i" + fi +done + +run_cmd="deepspeed $ds_args --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py ${gpt_options} $gpt2_args" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/mp2_256m.json b/examples/Megatron-LM/scripts/mp2_256m.json new file mode 100755 index 0000000..30a7d85 --- /dev/null +++ b/examples/Megatron-LM/scripts/mp2_256m.json @@ -0,0 +1,41 @@ +{ + "train_batch_size": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "weight_decay": 1e-2 + } + }, + "zero_optimization": { + "stage": 2 + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": false, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bfloat16": { + "enabled": true, + "loss_scale": 1.0 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": false + }, + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 5, + "module_depth": -1, + "top_modules": 1, + "detailed": true, + "output_file": null + } +} diff --git a/examples/Megatron-LM/scripts/mp2_256m.sh b/examples/Megatron-LM/scripts/mp2_256m.sh new file mode 100755 index 0000000..59b05cc --- /dev/null +++ b/examples/Megatron-LM/scripts/mp2_256m.sh @@ -0,0 +1,48 @@ +#! /bin/bash + +# Change for multinode config +MP_SIZE=2 + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=4 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/mp2_256m.json" +gpt_options=" \ + --model-parallel-size ${MP_SIZE} \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 1 \ + --seq-length 512 \ + --max-position-embeddings 1024 \ + --train-iters 100000 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend ccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --deepspeed-activation-checkpointing \ + --bf16 \ +" +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} pretrain_gpt2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/examples/Megatron-LM/scripts/presplit_sentences_json.py b/examples/Megatron-LM/scripts/presplit_sentences_json.py new file mode 100644 index 0000000..68d0222 --- /dev/null +++ b/examples/Megatron-LM/scripts/presplit_sentences_json.py @@ -0,0 +1,27 @@ +""" +Usage: +python scripts/presplit_sentences_json.py +""" + +import sys +import json + +import nltk + +nltk.download('punkt') + +input_file = sys.argv[1] +output_file = sys.argv[2] + +line_seperator = "\n" + +with open(input_file, 'r') as ifile: + with open(output_file, "w") as ofile: + for doc in ifile.readlines(): + parsed = json.loads(doc) + sent_list = [] + for line in parsed['text'].split('\n'): + if line != '\n': + sent_list.extend(nltk.tokenize.sent_tokenize(line)) + parsed['text'] = line_seperator.join(sent_list) + ofile.write(json.dumps(parsed)+'\n') diff --git a/examples/Megatron-LM/scripts/pretrain_bert.sh b/examples/Megatron-LM/scripts/pretrain_bert.sh new file mode 100755 index 0000000..e7b9769 --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_bert.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +RANK=0 +WORLD_SIZE=1 + +python pretrain_bert.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 4 \ + --seq-length 512 \ + --max-preds-per-seq 80 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --save checkpoints/bert_345m \ + --load checkpoints/bert_345m \ + --resume-dataloader \ + --train-data wikipedia \ + --lazy-loader \ + --tokenizer-type BertWordPieceTokenizer \ + --tokenizer-model-type bert-large-uncased \ + --presplit-sentences \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.0001 \ + --lr-decay-style linear \ + --lr-decay-iters 990000 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --fp16 \ + --fp32-layernorm \ + --fp32-embedding diff --git a/examples/Megatron-LM/scripts/pretrain_bert_distributed.sh b/examples/Megatron-LM/scripts/pretrain_bert_distributed.sh new file mode 100755 index 0000000..fe40dc2 --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_bert_distributed.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_bert.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 4 \ + --seq-length 512 \ + --max-preds-per-seq 80 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --save checkpoints/bert_345m \ + --load checkpoints/bert_345m \ + --resume-dataloader \ + --train-data wikipedia \ + --lazy-loader \ + --tokenizer-type BertWordPieceTokenizer \ + --tokenizer-model-type bert-large-uncased \ + --presplit-sentences \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.0001 \ + --lr-decay-style linear \ + --lr-decay-iters 990000 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --fp16 \ + --fp32-layernorm \ + --fp32-embedding + diff --git a/examples/Megatron-LM/scripts/pretrain_bert_model_parallel.sh b/examples/Megatron-LM/scripts/pretrain_bert_model_parallel.sh new file mode 100644 index 0000000..2cca630 --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_bert_model_parallel.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_bert.py \ + --model-parallel-size 2 \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 4 \ + --seq-length 512 \ + --max-preds-per-seq 80 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --save checkpoints/bert_345m_mp2 \ + --load checkpoints/bert_345m_mp2 \ + --resume-dataloader \ + --train-data wikipedia \ + --lazy-loader \ + --tokenizer-type BertWordPieceTokenizer \ + --tokenizer-model-type bert-large-uncased \ + --presplit-sentences \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.0001 \ + --lr-decay-style linear \ + --lr-decay-iters 990000 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --fp16 \ + --fp32-layernorm \ + --fp32-embedding + diff --git a/examples/Megatron-LM/scripts/pretrain_bert_sentencepiece.sh b/examples/Megatron-LM/scripts/pretrain_bert_sentencepiece.sh new file mode 100755 index 0000000..289d371 --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_bert_sentencepiece.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +RANK=0 +WORLD_SIZE=1 + +python pretrain_bert.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 4 \ + --seq-length 512 \ + --max-preds-per-seq 80 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --save checkpoints/bert_345m \ + --load checkpoints/bert_345m \ + --resume-dataloader \ + --train-data wikipedia \ + --lazy-loader \ + --tokenizer-type SentencePieceTokenizer \ + --tokenizer-model-type bpe \ + --tokenizer-path tokenizer.model \ + --presplit-sentences \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.0001 \ + --lr-decay-style linear \ + --lr-decay-iters 990000 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --fp16 \ + --fp32-layernorm \ + --fp32-embedding diff --git a/examples/Megatron-LM/scripts/pretrain_bert_tfrecords_distributed.sh b/examples/Megatron-LM/scripts/pretrain_bert_tfrecords_distributed.sh new file mode 100755 index 0000000..436c92c --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_bert_tfrecords_distributed.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_bert.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 4 \ + --seq-length 512 \ + --max-preds-per-seq 80 \ + --max-position-embeddings 512 \ + --train-iters 1000000 \ + --save checkpoints/bert_345m \ + --load checkpoints/bert_345m \ + --resume-dataloader \ + --use-tfrecords \ + --train-data \ + --valid-data \ + --test-data \ + --tokenizer-type BertWordPieceTokenizer \ + --tokenizer-model-type bert-large-uncased \ + --presplit-sentences \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.0001 \ + --lr-decay-style linear \ + --lr-decay-iters 990000 \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --fp16 \ + --fp32-layernorm \ + --fp32-embedding diff --git a/examples/Megatron-LM/scripts/pretrain_gpt2.sh b/examples/Megatron-LM/scripts/pretrain_gpt2.sh new file mode 100644 index 0000000..2cee4bf --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_gpt2.sh @@ -0,0 +1,34 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +RANK=0 +WORLD_SIZE=1 + +python pretrain_gpt2.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 8 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 320000 \ + --save checkpoints/gpt2_345m \ + --load checkpoints/gpt2_345m \ + --resume-dataloader \ + --train-data wikipedia \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --fp16 + + +set +x diff --git a/examples/Megatron-LM/scripts/pretrain_gpt2_distributed.sh b/examples/Megatron-LM/scripts/pretrain_gpt2_distributed.sh new file mode 100755 index 0000000..9c96020 --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_gpt2_distributed.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt2.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 8 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 320000 \ + --save checkpoints/gpt2_345m \ + --load checkpoints/gpt2_345m \ + --resume-dataloader \ + --train-data wikipedia \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --cache-dir cache \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --fp16 + + +set +x diff --git a/examples/Megatron-LM/scripts/pretrain_gpt2_model_parallel.sh b/examples/Megatron-LM/scripts/pretrain_gpt2_model_parallel.sh new file mode 100644 index 0000000..fd4ebf9 --- /dev/null +++ b/examples/Megatron-LM/scripts/pretrain_gpt2_model_parallel.sh @@ -0,0 +1,42 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=8 +# Change for multinode config +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 +NODE_RANK=0 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS \ + pretrain_gpt2.py \ + --model-parallel-size 2 \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --batch-size 8 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 320000 \ + --save checkpoints/gpt2_345m_mp2 \ + --load checkpoints/gpt2_345m_mp2 \ + --resume-dataloader \ + --train-data webtext \ + --lazy-loader \ + --tokenizer-type GPT2BPETokenizer \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --no-load-optim \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --warmup .01 \ + --checkpoint-activations \ + --fp16 \ + --deepspeed +set +x diff --git a/examples/Megatron-LM/scripts/run_gpt2_eval.py b/examples/Megatron-LM/scripts/run_gpt2_eval.py new file mode 100644 index 0000000..516448d --- /dev/null +++ b/examples/Megatron-LM/scripts/run_gpt2_eval.py @@ -0,0 +1,89 @@ +""" +example usage: +python scripts/run_gpt2_eval.py \ + --model-parallel-size 1 \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --model-path \ + --data-path \ + --batch-size 16 \ + --cache-dir +""" +import argparse +import subprocess + +parser = argparse.ArgumentParser('run zero shot GPT2 eval') +parser.add_argument('--model-path', type=str, required=True, + help='Saved model path for evaluation') +parser.add_argument('--batch-size', type=int, default=4, + help='batch size to use for evaluation') +parser.add_argument('--num-attention-heads', type=int, default=12, + help='num of transformer attention heads') +parser.add_argument('--hidden-size', type=int, default=768, + help='tansformer hidden size') +parser.add_argument('--num-layers', type=int, default=12, + help='num decoder layers') +parser.add_argument('--data-path', type=str, required=True, + help='Data path for evaluation data') +parser.add_argument('--cloze-eval', action='store_true', + help='Run lambada cloze eval instead of perplexity eval.') +parser.add_argument('--webtext-eval', action='store_true', + help='Run webtext PPL eval instead of wikitext PPL eval.') +parser.add_argument('--eval-iters', default=5000, type=int, + help='number of iterations to run webtext evaluation') +parser.add_argument('--model-parallel-size', type=int, default=1, + help='model parallel size to use') +parser.add_argument('--load-openai', action='store_true', + help='Load weights from saved openai/hf checkpoints') +parser.add_argument('--cache-dir', type=str, default='cache', + help='directory to cache gpt2 tokenizers') +args = parser.parse_args() + +multinode_args = '' +if args.model_parallel_size > 1: + multinode_args += ' -m torch.distributed.launch --nproc_per_node {} '.format(args.model_parallel_size) + +CMD = ' --model-parallel-size {model_par} \ + --num-layers {nlayers} \ + --hidden-size {hidden} \ + --log-interval 100 \ + --load {model} \ + --eval-batch-size {batch} \ + --num-attention-heads {natt} \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --tokenizer-type GPT2BPETokenizer \ + --text-key text \ + --distributed-backend nccl \ + --hidden-dropout 0.1 \ + --attention-dropout 0.1 \ + --fp16 \ + --overlapping-eval 32 \ + --cache-dir {cache} '.format(model_par=args.model_parallel_size, + nlayers=args.num_layers, + hidden=args.hidden_size, + model=args.model_path, + batch=args.batch_size, + natt=args.num_attention_heads, + cache=args.cache_dir) + +if args.load_openai: + CMD += ' --load-openai ' +if args.cloze_eval: + CMD += ' --cloze-eval ' + CMD = 'evaluate_gpt2.py' + CMD + print('Running Lambada Eval Command:', flush=True) +elif args.webtext_eval: + CMD += '--train-iters 0 --eval-iters {} --test-data {} --loose-json '.format(args.eval_iters, args.data_path) + CMD = 'pretrain_gpt2.py' + CMD + print('Running Webtext Eval Command:', flush=True) +else: + CMD += ' --valid-data {} '.format(args.data_path) + CMD = 'evaluate_gpt2.py' + CMD + print('Running PPL Eval Command:', flush=True) + +CMD = 'python3 '+multinode_args+CMD +print(CMD, flush=True) + +subprocess.call(CMD.split()) diff --git a/examples/Megatron-LM/scripts/split_json.py b/examples/Megatron-LM/scripts/split_json.py new file mode 100644 index 0000000..c0b1415 --- /dev/null +++ b/examples/Megatron-LM/scripts/split_json.py @@ -0,0 +1,119 @@ +""" +Takes a corpora of files (specified by `--input_files`) with json data separated +by newlines (loose json). Splits data into train.json, val.json, test.json files +under `output_dir`. + +Note: This code has the potential to override files with the names +train.json, val.json, test.json in `--output_dir`. +""" +import os +import argparse +import math +import random + +parser = argparse.ArgumentParser('resplit loose json data into train/val/test') +parser.add_argument('--input_files', nargs='+', required=True, + help='whitespace separated list of input data files') +parser.add_argument('--output_dir', required=True, + help='output directory where to put files') +parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0], + help='percentage of available data to use for val/test dataset') +args = parser.parse_args() + +def get_lines(filepath): + lines = [] + with open(filepath, 'r') as f: + for i, l in enumerate(f.readlines()): + l = l.strip() + lines.append(l) + return lines + +def get_splits(lines, line_counts): + all_lines = [] + line_idx = [] + file_mappings = [] + for i, l in enumerate(lines): + all_lines.extend(l) + line_idx.extend(list(range(len(l)))) + file_mappings.extend([i]*len(l)) + + indices = list(range(len(all_lines))) + random.shuffle(indices) + all_lines = [all_lines[idx] for idx in indices] + line_idx = [line_idx[idx] for idx in indices] + file_mappings = [file_mappings[idx] for idx in indices] + + splits = [] + mappings = [] + start = 0 + for end in line_counts: + end += start + splits.append(all_lines[start:end]) + mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end])) + start = end + return splits, mappings + +def format_mappings(line_idx, file_mappings): + lines = [] + for m, l in zip(file_mappings, line_idx): + lines.append(str(m).strip()+'\t'+str(l).strip()) + return lines + + +def get_filepaths(filepaths, output_dir): + paths = [] + train_path = 'train.json' + dev_path = 'dev.json' + test_path = 'test.json' + paths.append(os.path.join(output_dir, train_path)) + paths.append(os.path.join(output_dir, dev_path)) + paths.append(os.path.join(output_dir, test_path)) + return paths + +def write_files(lines, mappings, filepaths): + for l, m, path in zip(lines, mappings, filepaths): + write_file(l, path) + write_mapping_file(m, path) + +def write_file(lines, path): + print('Writing:', path) + with open(path, 'w') as f: + for l in lines: + f.write(l+'\n') + +def write_mapping_file(m, path): + path = path+'.map' + m = [get_mapping_header()]+m + write_file(m, path) + +def get_mapping_header(): + return 'file\tline #' + +if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + +lines = [] + +for filepath in args.input_files: + _lines = get_lines(filepath) + lines.append(_lines) + +#calculate number of lines to use for each +line_counts = [len(l) for l in lines] +total_lines = sum(line_counts) +dev_percent = args.test_percent[0] +dev_lines = math.ceil(dev_percent*total_lines) +test_percent = 0 +if len(args.test_percent)==2: + test_percent=args.test_percent[1] +test_lines = math.ceil(test_percent*total_lines) +train_lines = total_lines-(test_lines+dev_lines) +normed_lines = [train_lines, dev_lines, test_lines] +normed_lines = [int(l) for l in normed_lines] + + +splits, mappings = get_splits(lines, normed_lines) +filepaths = get_filepaths(args.input_files, args.output_dir) +print('Writing output to:', filepaths) +write_files(splits, mappings, filepaths) + diff --git a/examples/Megatron-LM/utils.py b/examples/Megatron-LM/utils.py new file mode 100644 index 0000000..acbcfbd --- /dev/null +++ b/examples/Megatron-LM/utils.py @@ -0,0 +1,411 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for logging and serialization""" + +import os +import random +import time +import numpy as np +import torch +from deepspeed.accelerator.real_accelerator import get_accelerator + +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP +from fp16 import FP16_Optimizer +import mpu +import model + + +def print_rank_0(message): + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(message, flush=True) + else: + print(message, flush=True) + + +def print_args(args): + """Print arguments.""" + + print('arguments:', flush=True) + for arg in vars(args): + dots = '.' * (29 - len(arg)) + print(' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) + + +def print_params_min_max_norm(optimizer, iteration): + """Print min, max, and norm of all parameters.""" + index = 0 + rank = torch.distributed.get_rank() + string = 'iteration, rank, index, model-parallel,min, max, norm\n' + optimizer_ = optimizer + if isinstance(optimizer, FP16_Optimizer): + optimizer_ = optimizer.optimizer + for param_group in optimizer_.param_groups: + for param in param_group['params']: + index += 1 + min_ = param.data.min() + max_ = param.data.max() + norm = param.data.norm() + string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( + iteration, rank, index, int(param.model_parallel)) + string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) + print(string, flush=True) + + +class Timers: + """Group of timers.""" + + class Timer: + """Timer.""" + + def __init__(self, name): + self.name_ = name + self.elapsed_ = 0.0 + self.started_ = False + self.start_time = time.time() + + def start(self): + """Start the timer.""" + assert not self.started_, 'timer has already been started' + get_accelerator().synchronize() + self.start_time = time.time() + self.started_ = True + + def stop(self): + """Stop the timer.""" + assert self.started_, 'timer is not started' + get_accelerator().synchronize() + self.elapsed_ += (time.time() - self.start_time) + self.started_ = False + + def reset(self): + """Reset timer.""" + self.elapsed_ = 0.0 + self.started_ = False + + def elapsed(self, reset=True): + """Calculate the elapsed time.""" + started_ = self.started_ + # If the timing in progress, end it first. + if self.started_: + self.stop() + # Get the elapsed time. + elapsed_ = self.elapsed_ + # Reset the elapsed time + if reset: + self.reset() + # If timing was in progress, set it back. + if started_: + self.start() + return elapsed_ + + def __init__(self): + self.timers = {} + + def __call__(self, name): + if name not in self.timers: + self.timers[name] = self.Timer(name) + return self.timers[name] + + def log(self, names, normalizer=1.0, reset=True): + """Log a group of timers.""" + assert normalizer > 0.0 + string = 'time (ms)' + for name in names: + elapsed_time = self.timers[name].elapsed( + reset=reset) * 1000.0/ normalizer + string += ' | {}: {:.2f}'.format(name, elapsed_time) + print_rank_0(string) + + +def report_memory(name): + """Simple GPU memory report.""" + + mega_bytes = 1024.0 * 1024.0 + string = name + ' memory (MB)' + string += ' | allocated: {}'.format( + get_accelerator().memory_allocated() / mega_bytes) + string += ' | max allocated: {}'.format( + get_accelerator().max_memory_allocated() / mega_bytes) + string += ' | cached: {}'.format(get_accelerator().memory_cached() / mega_bytes) + string += ' | max cached: {}'.format( + get_accelerator().max_memory_cached()/ mega_bytes) + print_rank_0(string) + + +def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False): + if release: + d = 'release' + else: + d = 'iter_{:07d}'.format(iteration) + if zero: + dp_rank = mpu.get_data_parallel_rank() + d += '_zero_dp_rank_{}'.format(dp_rank) + return os.path.join(checkpoints_path, d, + 'mp_rank_{:02d}'.format(mpu.get_model_parallel_rank()), + 'model_optim_rng.pt') + + +def ensure_directory_exists(filename): + dirname = os.path.dirname(filename) + if not os.path.exists(dirname): + os.makedirs(dirname) + + +def get_checkpoint_tracker_filename(checkpoints_path): + return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') + + +def save_zero_checkpoint(args, iteration, optimizer): + zero_sd = {'iteration': iteration, + 'optimizer_state_dict': optimizer.state_dict()} + zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) + ensure_directory_exists(zero_checkpoint_name) + torch.save(zero_sd, zero_checkpoint_name) + print(' successfully saved {}'.format(zero_checkpoint_name)) + +def save_checkpoint(iteration, model, optimizer, + lr_scheduler, args): + """Save a model checkpoint.""" + if args.deepspeed: + save_ds_checkpoint(iteration, model, args) + else: + # Only rank zer0 of the data parallel writes to the disk. + if isinstance(model, torchDDP): + model = model.module + + if mpu.get_data_parallel_rank() == 0: + checkpoint_name = get_checkpoint_name(args.save, iteration) + print('global rank {} is saving checkpoint at iteration {:7d} to {}'. + format(torch.distributed.get_rank(), iteration, checkpoint_name)) + + sd = {} + sd['iteration'] = iteration + sd['model'] = model.state_dict() + + # Optimizer stuff. + if not args.no_save_optim: + if optimizer is not None: + sd['optimizer'] = optimizer.state_dict() + if lr_scheduler is not None: + sd['lr_scheduler'] = lr_scheduler.state_dict() + + # rng states. + if not args.no_save_rng: + sd['random_rng_state'] = random.getstate() + sd['np_rng_state'] = np.random.get_state() + sd['torch_rng_state'] = torch.get_rng_state() + sd['cuda_rng_state'] = get_accelerator().get_rng_state() + sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() + + + ensure_directory_exists(checkpoint_name) + torch.save(sd, checkpoint_name) + print(' successfully saved {}'.format(checkpoint_name)) + + # Wait so everyone is done (necessary) + torch.distributed.barrier() + # And update the latest iteration + if torch.distributed.get_rank() == 0: + tracker_filename = get_checkpoint_tracker_filename(args.save) + with open(tracker_filename, 'w') as f: + f.write(str(iteration)) + # Wait so everyone is done (not necessary) + torch.distributed.barrier() + +def save_ds_checkpoint(iteration, model, args): + """Save a model checkpoint.""" + + sd = {} + sd['iteration'] = iteration + # rng states. + if not args.no_save_rng: + sd['random_rng_state'] = random.getstate() + sd['np_rng_state'] = np.random.get_state() + sd['torch_rng_state'] = torch.get_rng_state() + sd['cuda_rng_state'] = get_accelerator().get_rng_state() + sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() + + model.save_checkpoint(args.save, iteration, client_state = sd) + + +def get_checkpoint_iteration(args): + # Read the tracker file and set the iteration. + tracker_filename = get_checkpoint_tracker_filename(args.load) + if not os.path.isfile(tracker_filename): + print_rank_0('WARNING: could not find the metadata file {} '.format( + tracker_filename)) + print_rank_0(' will not load any checkpoints and will start from ' + 'random') + return 0, False, False + iteration = 0 + release = False + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + try: + iteration = int(metastring) + except ValueError: + release = metastring == 'release' + if not release: + print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( + tracker_filename)) + exit() + + assert iteration > 0 or release, 'error parsing metadata file {}'.format( + tracker_filename) + + return iteration, release, True + +def load_checkpoint(model, optimizer, lr_scheduler, args): + """Load a model checkpoint.""" + + iteration, release, success = get_checkpoint_iteration(args) + + if not success: + return 0 + + if args.deepspeed: + + checkpoint_name, sd = model.load_checkpoint(args.load, iteration) + + if checkpoint_name is None: + if mpu.get_data_parallel_rank() == 0: + print("Unable to load checkpoint.") + return iteration + + else: + + # Checkpoint. + checkpoint_name = get_checkpoint_name(args.load, iteration, release) + + if mpu.get_data_parallel_rank() == 0: + print('global rank {} is loading checkpoint {}'.format( + torch.distributed.get_rank(), checkpoint_name)) + + # Load the checkpoint. + sd = torch.load(checkpoint_name, map_location='cpu') + + if isinstance(model, torchDDP): + model = model.module + + # Model. + try: + model.load_state_dict(sd['model']) + except KeyError: + print_rank_0('A metadata file exists but unable to load model ' + 'from checkpoint {}, exiting'.format(checkpoint_name)) + exit() + + # Optimizer. + if not release and not args.finetune and not args.no_load_optim: + try: + if optimizer is not None: + optimizer.load_state_dict(sd['optimizer']) + if lr_scheduler is not None: + lr_scheduler.load_state_dict(sd['lr_scheduler']) + except KeyError: + print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' + 'Specify --no-load-optim or --finetune to prevent ' + 'attempting to load the optimizer ' + 'state.'.format(checkpoint_name)) + exit() + + # Iterations. + if args.finetune or release: + iteration = 0 + else: + try: + iteration = sd['iteration'] + except KeyError: + try: # Backward compatible with older checkpoints + iteration = sd['total_iters'] + except KeyError: + print_rank_0('A metadata file exists but Unable to load iteration ' + ' from checkpoint {}, exiting'.format(checkpoint_name)) + exit() + + # rng states. + if not release and not args.finetune and not args.no_load_rng: + try: + random.setstate(sd['random_rng_state']) + np.random.set_state(sd['np_rng_state']) + torch.set_rng_state(sd['torch_rng_state']) + get_accelerator().set_rng_state(sd['cuda_rng_state']) + mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states']) + except KeyError: + print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' + 'Specify --no-load-optim or --finetune to prevent ' + 'attempting to load the optimizer ' + 'state.'.format(checkpoint_name)) + exit() + + torch.distributed.barrier() + if mpu.get_data_parallel_rank() == 0: + print(' successfully loaded {}'.format(checkpoint_name)) + + return iteration + + +def load_weights(src, dst, dst2src=False): + """ + Loads weights from src to dst via in place copy. + src is a huggingface gpt2model, while dst is one of our models. + dst2src=True loads parameters from our models into huggingface's. + ^dst2src is still untested + """ + conv_layer = 'Conv1D' in str(type(src)) + for n, p in src.named_parameters(): + if dst2src: + data = dst._parameters[n].data + load = p.data + else: + data = p.data + load = dst._parameters[n].data + if conv_layer and 'weight' in n: + data = data.t().contiguous() + load.copy_(data) +# dst._parameters[n].data.copy_(data) + +def load_mlp(our, oai, dst2src=False): + load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) + load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) + +def load_attention(our, oai, dst2src=False): + load_weights(oai.c_attn, our.query_key_value, dst2src) + load_weights(oai.c_proj, our.dense, dst2src) + +def load_transformer_layer(our, oai, dst2src=False): + load_weights(oai.ln_1, our.input_layernorm, dst2src) + load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) + load_mlp(our.mlp, oai.mlp, dst2src) + load_attention(our.attention, oai.attn, dst2src) + +def move_weights(our, oai, dst2src=False): + """ + Loads weights from `oai` to `our` via in place copy. + `oai` is a huggingface gpt2model, while `our` is one of our models. + dst2src=True loads parameters from our models into huggingface's. + ^dst2src=True is still untested + """ +# while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): +# our=our.module + transformer_model = oai.transformer + load_weights(transformer_model.ln_f, our.transformer.final_layernorm, dst2src) + load_weights(transformer_model.wte, our.word_embeddings, dst2src) + load_weights(transformer_model.wpe, our.position_embeddings, dst2src) + + for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): + load_transformer_layer(our_layer, oai_layer, dst2src) diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..add6eaa --- /dev/null +++ b/examples/README.md @@ -0,0 +1,8 @@ + +# DeepSpeed +This repo contains example models that use [DeepSpeed](https://github.com/microsoft/DeepSpeed). + +# Note on Megatron examples + +This is a porting of [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/megatron/Megatron-LM) that supports Intel GPU devices. + diff --git a/examples/autotuning/.gitignore b/examples/autotuning/.gitignore new file mode 100644 index 0000000..82319e4 --- /dev/null +++ b/examples/autotuning/.gitignore @@ -0,0 +1,4 @@ +autotuning_results* +autotuning_exps* +output* +mnli diff --git a/examples/autotuning/README.md b/examples/autotuning/README.md new file mode 100644 index 0000000..d028a94 --- /dev/null +++ b/examples/autotuning/README.md @@ -0,0 +1,3 @@ +# Autotuning Examples + +This showcases the [autotuning](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning) feature in DeepSpeed (DS). diff --git a/examples/autotuning/hf/README.md b/examples/autotuning/hf/README.md new file mode 100644 index 0000000..567deda --- /dev/null +++ b/examples/autotuning/hf/README.md @@ -0,0 +1,62 @@ +# Autotuning Hugging Face Examples + +This showcases the [autotuning](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning) feature in DeepSpeed (DS) with Hugging Face (HF) models. + +## List of Models + +- [DistilBERT](distilbert) +- [BERT-base](bert-base) +- [BERT-large](bert-large) +- [GPT2](gpt2) +- [GPT2-medium](gpt2-medium) +- [GPT2-large](gpt2-large) +- [GPT2-xl](gpt2-xl) +- [DeBERTa](deberta) + +Each model folder has a `test_tune.sh` script: + +- `./test_tune.sh tune` tunes the model training and then runs it using the selected tuned DeepSpeed configuration. +- `./test_tune.sh 0` runs the model using HF without DeepSpeed. +- `./test_tune.sh z0` runs the model using HF + DS with ZeRO optimization disabled. +- `./test_tune.sh z1` runs the model using HF + DS with ZeRO optimization stage 1. +- `./test_tune.sh z2` runs the model using HF + DS with ZeRO optimization stage 2. +- `./test_tune.sh z3` runs the model using HF + DS with ZeRO optimization stage 3. + + +## Testing Environment + +The training runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | num_params | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | throughput improvement over baseline | autotuning time (mins) | number of experiments | +| :----------: | :--------: | :---------------------------: | :----------------------------------: | :----------------------------: | :----------------------------------: | :--------------------: | :-------------------: | +| DistilBERT | 66M | 5161.902 (gas = 1, mbs = 256) | 5305.067 (z = 0, gas = 1 mbs = 256) | 5305.067 (z0_gas1_tmbspg256) | 1.03x | 11 | 11 | +| BERT-base | 0.11B | 2502.236 (gas = 1,mbs = 128) | 2523.684 (z = 0, gas = 1, mbs = 128) | 2736.561 (z0_gas1_tmbspg235) | 1.09x | 35 | 34 | +| BERT-large | 0.34B | 742.692 (gas = 1,mbs = 64) | 766.929 (z = 1, gas = 1, mbs = 64) | 808.168 (z1_gas1_tmbspg93) | 1.09x | 36 | 22 | +| GPT2 | 0.12B | 284.142 (gas = 1,mbs = 8) | 397.827 (z = 1, gas = 1, mbs = 8) | 431.586 (z1_gas1_tmbspg14) | 1.52x | 25 | 17 | +| GPT2-medium | 0.35B | 71.61 (gas = 1, mbs = 2) | 142.211 (z = 1, gas = 1, mbs = 4) | 163.3 (z1_gas1_tmbspg6) | 2.28 | 15 | 25 | +| GPT2-large | 0.77B | 27.874 (gas = 1, mbs = 1) | 56.797 (z = 1, gas = 1, mbs = 2) | 69.061 (z = 1, mbs = 3) | 2.48x | 27 | 13 | +| GPT2-xl | 1.5B | Not runnable | 27.462 (gas = 1, mbs = 1) | 27.497 (z1_gas1_tmbspg1) | inf | 21 | 9 | +| DeBERTa | 1.5B | Not runnable | 140.587 (z = 1, gas = 1 mbs = 8) | 162.395 (z1_gas1_tmbspg11) | inf | 40 | 12 | diff --git a/examples/autotuning/hf/bert-base/README.md b/examples/autotuning/hf/bert-base/README.md new file mode 100644 index 0000000..02450fd --- /dev/null +++ b/examples/autotuning/hf/bert-base/README.md @@ -0,0 +1,58 @@ +# [bert-base-cased](https://huggingface.co/bert-base-cased) + +This model has the following configuration: + +- 12-layer +- 768 hidden dimension +- 12 attention heads +- 110M parameters. + +## Environment + +The training use fp32 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is set to `4096`. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS handtuned | HF + DS autotuning | +| ---------- | ----------------------------- | ------------------------------------ | ---------------------------- | +| BERT-base | 2502.236 (gas = 1, mbs = 128) | 2523.684 (z = 0, gas = 1, mbs = 128) | 2736.561 (z0_gas1_tmbspg235) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 35 mins +- Number of experiments: 34 +- Throughput Improvement over baseline: 1.09x + + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :---------------- | +| z0 | 9 | 2930.18 | z0_gas1_tmbspg235 | +| z1 | 7 | 2930.17 | z1_gas1_tmbspg235 | +| z2 | 8 | 2744.16 | z2_gas1_tmbspg235 | +| z3 | 10 | 2479.47 | z3_gas1_tmbspg238 | +| global | 34 | 2930.18 | z0_gas1_tmbspg235 | + +Tuning completed in 0:34:41.842250. Total number of experiments: 34. diff --git a/examples/autotuning/hf/bert-base/ds_config_tune.json b/examples/autotuning/hf/bert-base/ds_config_tune.json new file mode 100644 index 0000000..23a48dd --- /dev/null +++ b/examples/autotuning/hf/bert-base/ds_config_tune.json @@ -0,0 +1,12 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "max_train_batch_size": 4096, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/examples/autotuning/hf/bert-base/test_tune.sh b/examples/autotuning/hf/bert-base/test_tune.sh new file mode 100755 index 0000000..532efc9 --- /dev/null +++ b/examples/autotuning/hf/bert-base/test_tune.sh @@ -0,0 +1,114 @@ +TASK_NAME=mnli +MODEL_NAME=bert-base-cased +HF_PATH=~/projects +PER_DEVICE_TRAIN_BATCH_SIZE=64 +MAX_TRAIN_BATCH_SIZE=4096 +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./${TASK_NAME}/output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z0.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z0 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z1.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z1 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z2.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z2 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z3.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z3 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_tune.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/bert-large/README.md b/examples/autotuning/hf/bert-large/README.md new file mode 100644 index 0000000..157dba0 --- /dev/null +++ b/examples/autotuning/hf/bert-large/README.md @@ -0,0 +1,55 @@ +# [bert-large-uncased](https://huggingface.co/bert-large-uncased) + +This model has the following configuration: + +- 24-layer +- 1024 hidden dimension +- 16 attention heads +- 336M parameters + +The training use fp32 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS handtuned | HF + DS autotuning | +| ---------- | --------------------------- | --------------------------------- | -------------------------- | +| BERT-large | 742.692 (gas = 1, mbs = 64) | 766.929 (z = 1, gas =1, mbs = 64) | 808.168 (z1_gas1_tmbspg93) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 36 mins +- Number of experiments: 22 +- Throughput Improvement over baseline: 1.09x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :--------------- | +| z0 | 6 | 835.244 | z0_gas1_tmbspg93 | +| z1 | 6 | 842.243 | z1_gas1_tmbspg93 | +| z2 | 9 | 764.524 | z2_gas1_tmbspg94 | +| z3 | 1 | 0 | z3_gas1_tmbspg94 | +| global | 22 | 842.243 | z1_gas1_tmbspg93 | + +Tuning completed in 0:36:16.261417. Total number of experiments: 23. diff --git a/examples/autotuning/hf/bert-large/ds_config_tune.json b/examples/autotuning/hf/bert-large/ds_config_tune.json new file mode 100644 index 0000000..e79f9c4 --- /dev/null +++ b/examples/autotuning/hf/bert-large/ds_config_tune.json @@ -0,0 +1,11 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/examples/autotuning/hf/bert-large/test_tune.sh b/examples/autotuning/hf/bert-large/test_tune.sh new file mode 100755 index 0000000..e63f917 --- /dev/null +++ b/examples/autotuning/hf/bert-large/test_tune.sh @@ -0,0 +1,114 @@ +TASK_NAME=mnli +MODEL_NAME=bert-large-uncased +HF_PATH=~/projects +PER_DEVICE_TRAIN_BATCH_SIZE=64 +MAX_TRAIN_BATCH_SIZE=4096 +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./${TASK_NAME}/output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z0.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z0 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z1.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z1 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z2.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z2 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z3.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z3 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_tune.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/deberta/README.md b/examples/autotuning/hf/deberta/README.md new file mode 100644 index 0000000..9144376 --- /dev/null +++ b/examples/autotuning/hf/deberta/README.md @@ -0,0 +1,72 @@ +# [deberta-v2-xxlarge-mnli](https://huggingface.co/microsoft/deberta-v2-xxlarge) + +This model has the following configuration: + +- 48-layer +- 1536 hidden dimension +- 1.5B parameters. + +Refer to [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://github.com/microsoft/DeBERTa). +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg), reduce_bucket_size (rbs), allgather_bucket_size (abs). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | --------------------------------- | ------------------------------ | +| DeBERTa | Not runnable | 140.587 (z = 1, gas = 1 mbs = 8), | 162.395 (z1_gas1_tmbspg11) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. +### Fast-mode +- Autotuning time: 40 mins +- Number of experiments: 12 +- Throughput Improvement over baseline: Inf + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :--------------- | +| z0 | 1 | 0 | z0_gas1_tmbspg1 | +| z1 | 6 | 177.843 | z1_gas1_tmbspg11 | +| z2 | 4 | 154.002 | z2_gas1_tmbspg14 | +| z3 | 1 | 0 | z3_gas1_tmbspg14 | +| global | 12 | 177.843 | z1_gas1_tmbspg11 | + +Tuning completed in 0:39:25.253998. Total number of experiments: 12. + +### Full-mode ("fast" set to false) +- Autotuning time: 1 hr 2 mins +- Number of experiments: 24 +- Throughput Improvement over baseline: Inf + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :---------------- | --------------: | --------------: | :------------------------------------- | +| z0 | 1 | 0 | z0_gas1_tmbspg1 | +| z1 | 6 | 177.843 | z1_gas1_tmbspg11 | +| z1_rbs_abs_tmbspg | 12 | 193.577 | z1_rbs5.0e+07_abs1.0e+09_gas1_tmbspg11 | +| z2 | 4 | 154.002 | z2_gas1_tmbspg14 | +| z3 | 1 | 0 | z3_gas1_tmbspg14 | +| global | 24 | 193.577 | z1_rbs5.0e+07_abs1.0e+09_gas1_tmbspg11 | + +Tuning completed in 1:02:32.759424. Total number of experiments: 24. diff --git a/examples/autotuning/hf/deberta/ds_config_fp16_tune.json b/examples/autotuning/hf/deberta/ds_config_fp16_tune.json new file mode 100644 index 0000000..b405929 --- /dev/null +++ b/examples/autotuning/hf/deberta/ds_config_fp16_tune.json @@ -0,0 +1,16 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "fp16": { + "enabled": true, + "initial_scale_power": 12 + }, + "autotuning": { + "enabled": true, + "overwrite": false, + "fast": true, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} \ No newline at end of file diff --git a/examples/autotuning/hf/deberta/test_tune.sh b/examples/autotuning/hf/deberta/test_tune.sh new file mode 100755 index 0000000..d4de499 --- /dev/null +++ b/examples/autotuning/hf/deberta/test_tune.sh @@ -0,0 +1,127 @@ +MODEL_NAME=microsoft/deberta-v2-xxlarge +TASK_NAME=mnli +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/distilbert/README.md b/examples/autotuning/hf/distilbert/README.md new file mode 100644 index 0000000..dce9920 --- /dev/null +++ b/examples/autotuning/hf/distilbert/README.md @@ -0,0 +1,69 @@ +# [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) + +This model has the following configuration: + +- 12-layer +- 768 hidden dimension +- 12 attention heads +- 66M parameters. + +## Environment + +The training uses 1 node with 16 Nvidia V100 GPUs, fp32, max_train_batch_size = 4096. The autotuning uses the same hardware resource as the training. `"max_train_batch_size"` is set to `4096`. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | ----------------------------- | ------------------------------------ | ------------------------------ | +| DistilBERT | 5161.902 (gas = 1, mbs = 256) | 5305.067 (z = 0, gas = 1 mbs = 256), | 5305.067 (z0_gas1_tmbspg256) | + +3700.296 + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 11 mins +- Number of experiments: 11 +- Throughput Improvement: 1.03x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :---------------- | +| z0 | 5 | 5759.96 | z0_gas1_tmbspg256 | +| z1 | 2 | 5667.06 | z1_gas1_tmbspg256 | +| z2 | 2 | 5366.97 | z2_gas1_tmbspg256 | +| z3 | 2 | 4892.49 | z3_gas1_tmbspg256 | +| global | 11 | 5759.96 | z0_gas1_tmbspg256 | + +Tuning completed in 0:10:45.085016. Total number of experiments: 11. + + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :----------------- | +| z0 | 7 | 5759.98 | z0_gas22_tmbspg179 | +| z1 | 2 | 5543.49 | z1_gas1_tmbspg269 | +| z2 | 2 | 5044.88 | z2_gas15_tmbspg269 | +| z3 | 2 | 4627.63 | z3_gas1_tmbspg269 | +| global | 13 | 5759.98 | z0_gas22_tmbspg179 | + +Tuning completed in 0:25:44.502148. Total number of experiments: 13. diff --git a/examples/autotuning/hf/distilbert/ds_config_tune.json b/examples/autotuning/hf/distilbert/ds_config_tune.json new file mode 100644 index 0000000..23a48dd --- /dev/null +++ b/examples/autotuning/hf/distilbert/ds_config_tune.json @@ -0,0 +1,12 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "max_train_batch_size": 4096, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/examples/autotuning/hf/distilbert/test_tune.sh b/examples/autotuning/hf/distilbert/test_tune.sh new file mode 100755 index 0000000..08b92d5 --- /dev/null +++ b/examples/autotuning/hf/distilbert/test_tune.sh @@ -0,0 +1,119 @@ +TASK_NAME=mnli +MODEL_NAME=distilbert-base-uncased +HF_PATH=~/projects +PER_DEVICE_TRAIN_BATCH_SIZE=64 +MAX_TRAIN_BATCH_SIZE=4096 +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./${TASK_NAME}/output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z0.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z1.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z2.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z3.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_tune.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/dsconfigs/ds_config_fp16_tune.json b/examples/autotuning/hf/dsconfigs/ds_config_fp16_tune.json new file mode 100644 index 0000000..7ae3116 --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_fp16_tune.json @@ -0,0 +1,15 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "fp16": { + "enabled": true + }, + "autotuning": { + "enabled": true, + "overwrite": false, + "fast": true, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_fp16_z0.json b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z0.json new file mode 100644 index 0000000..ff375bb --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z0.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": true + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_fp16_z1.json b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z1.json new file mode 100644 index 0000000..209706d --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z1.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 1 + }, + "fp16": { + "enabled": true + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_fp16_z2.json b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z2.json new file mode 100644 index 0000000..d3782ab --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z2.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 2 + }, + "fp16": { + "enabled": true + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_fp16_z3.json b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z3.json new file mode 100644 index 0000000..d0affd2 --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_fp16_z3.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 3 + }, + "fp16": { + "enabled": true + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_tune.json b/examples/autotuning/hf/dsconfigs/ds_config_tune.json new file mode 100644 index 0000000..413e196 --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_tune.json @@ -0,0 +1,12 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "fast": true, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_z0.json b/examples/autotuning/hf/dsconfigs/ds_config_z0.json new file mode 100644 index 0000000..6247e56 --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_z0.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 0 + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_z1.json b/examples/autotuning/hf/dsconfigs/ds_config_z1.json new file mode 100644 index 0000000..fd39970 --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_z1.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 1 + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_z2.json b/examples/autotuning/hf/dsconfigs/ds_config_z2.json new file mode 100644 index 0000000..b898aee --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_z2.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 2 + } +} diff --git a/examples/autotuning/hf/dsconfigs/ds_config_z3.json b/examples/autotuning/hf/dsconfigs/ds_config_z3.json new file mode 100644 index 0000000..5b11886 --- /dev/null +++ b/examples/autotuning/hf/dsconfigs/ds_config_z3.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 3 + } +} diff --git a/examples/autotuning/hf/gpt2-large/README.md b/examples/autotuning/hf/gpt2-large/README.md new file mode 100644 index 0000000..a736db4 --- /dev/null +++ b/examples/autotuning/hf/gpt2-large/README.md @@ -0,0 +1,59 @@ +# [gpt2-large](https://huggingface.co/gpt2-large) + +This model has the following configuration: + +- 36-layer +- 1280 hidden dimension +- 20 attention heads +- 774M parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0)datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | ------------------------ | ------------------------------ | +| GPT2-large | 27.874 (mbs = 1) | 56.797 (z = 1, mbs = 2), | 69.061 (z = 1, mbs = 3) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 27 mins +- Number of experiments: 13 +- Throughput Improvement over baseline: 2.48x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :-------------- | +| z0 | 4 | 59.0229 | z0_gas1_tmbspg2 | +| z1 | 5 | 87.3017 | z1_gas1_tmbspg3 | +| z2 | 3 | 77.8338 | z2_gas1_tmbspg3 | +| z3 | 1 | 0 | z3_gas1_tmbspg3 | +| global | 13 | 87.3017 | z1_gas1_tmbspg3 | + +Tuning completed in 0:27:33.988447. Total number of experiments: 13. diff --git a/examples/autotuning/hf/gpt2-large/test_tune.sh b/examples/autotuning/hf/gpt2-large/test_tune.sh new file mode 100755 index 0000000..c5fa9b6 --- /dev/null +++ b/examples/autotuning/hf/gpt2-large/test_tune.sh @@ -0,0 +1,132 @@ +MODEL_NAME=gpt2-large +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/gpt2-medium/README.md b/examples/autotuning/hf/gpt2-medium/README.md new file mode 100644 index 0000000..e97a1f9 --- /dev/null +++ b/examples/autotuning/hf/gpt2-medium/README.md @@ -0,0 +1,57 @@ +# [gpt2-medium](https://huggingface.co/gpt2-medium) + +This model has the following configuration: +- 24-layer +- 1024 hidden dimension +- 16 attention heads +- 345M parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ----------- | ------------------------ | --------------------------------- | ------------------------------ | +| GPT2-medium | 71.61 (gas = 1, mbs = 2) | 142.211 (z = 1, gas = 1, mbs = 4) | 163.3 (z1_gas1_tmbspg6) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 25 mins +- Number of experiments: 15 +- Throughput Improvement over baseline: 2.28x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :-------------- | +| z0 | 6 | 167.688 | z0_gas1_tmbspg5 | +| z1 | 5 | 175.46 | z1_gas1_tmbspg6 | +| z2 | 3 | 161.619 | z2_gas1_tmbspg6 | +| z3 | 1 | 0 | z3_gas1_tmbspg6 | +| global | 15 | 175.46 | z1_gas1_tmbspg6 | + +Tuning completed in 0:25:18.653731. Total number of experiments: 15. diff --git a/examples/autotuning/hf/gpt2-medium/test_tune.sh b/examples/autotuning/hf/gpt2-medium/test_tune.sh new file mode 100755 index 0000000..567deb4 --- /dev/null +++ b/examples/autotuning/hf/gpt2-medium/test_tune.sh @@ -0,0 +1,142 @@ +MODEL_NAME=gpt2-medium +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --block_size 512 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune_test" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune_test.json \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune_test \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/gpt2-xl/README.md b/examples/autotuning/hf/gpt2-xl/README.md new file mode 100644 index 0000000..f6d81b2 --- /dev/null +++ b/examples/autotuning/hf/gpt2-xl/README.md @@ -0,0 +1,56 @@ +# [gpt2-xl](https://huggingface.co/gpt2-xl) + +This model has the following configuration: +- 48-layer +- 1600 hidden dimension +- 25 attention heads +- 1.5B parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | --------------------------------- | -------------------------------- | +| GPT2-xl | Not runnable | Zero1 (27.462, gas = 1, mbs = 1), | Zero1 (27.497, gas = 1, mbs = 1) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 21 mins +- Number of experiments: 9 +- Throughput Improvement over baseline: Inf + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :-------------- | +| z1 | 3 | 40.1749 | z1_gas1_tmbspg1 | +| z2 | 3 | 33.0472 | z2_gas1_tmbspg1 | +| z3 | 3 | 12.8604 | z3_gas1_tmbspg1 | +| global | 9 | 40.1749 | z1_gas1_tmbspg1 | + +Tuning completed in 0:20:55.156000. Total number of experiments: 9. diff --git a/examples/autotuning/hf/gpt2-xl/test_tune.sh b/examples/autotuning/hf/gpt2-xl/test_tune.sh new file mode 100755 index 0000000..3c14463 --- /dev/null +++ b/examples/autotuning/hf/gpt2-xl/test_tune.sh @@ -0,0 +1,142 @@ +MODEL_NAME=gpt2-xl +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=50 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --block_size 512 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune_test" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune_test.json \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune_test \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/examples/autotuning/hf/gpt2/README.md b/examples/autotuning/hf/gpt2/README.md new file mode 100644 index 0000000..bb42691 --- /dev/null +++ b/examples/autotuning/hf/gpt2/README.md @@ -0,0 +1,59 @@ +# [gpt2](https://huggingface.co/gpt2) + +This model has the following configuration: + +- 12-layer +- 768 hidden dimension +- 12 attention heads +- 117M parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | ------------------------ | ------------------------------ | +| GPT2 | 284.142 (mbs = 8) | 397.827 (z = 1, mbs = 8) | 431.586 (z1_gas1_tmbspg15) | + + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 25 mins +- Number of experiments: 17 +- Throughput Improvement over baseline: 1.52x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :--------------- | +| z0 | 9 | 441.693 | z0_gas1_tmbspg11 | +| z1 | 6 | 452.004 | z1_gas1_tmbspg15 | +| z2 | 1 | 0 | z2_gas1_tmbspg15 | +| z3 | 1 | 0 | z3_gas1_tmbspg15 | +| global | 17 | 452.004 | z1_gas1_tmbspg15 | + +Tuning completed in 0:24:19.976427. Total number of experiments: 17. diff --git a/examples/autotuning/hf/gpt2/test_tune.sh b/examples/autotuning/hf/gpt2/test_tune.sh new file mode 100755 index 0000000..b570c45 --- /dev/null +++ b/examples/autotuning/hf/gpt2/test_tune.sh @@ -0,0 +1,133 @@ +MODEL_NAME=gpt2 +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/examples/pipeline_parallelism/alexnet.py b/examples/pipeline_parallelism/alexnet.py new file mode 100644 index 0000000..03c77a0 --- /dev/null +++ b/examples/pipeline_parallelism/alexnet.py @@ -0,0 +1,47 @@ +# +# Implementation of AlexNet for illustrative purposes. The train.py driver +# can import AlexNet from here or directly from torchvision. +# +# Taken from torchvision.models.alexnet: +# https://pytorch.org/docs/1.6.0/_modules/torchvision/models/alexnet.html#alexnet + + +import torch +import torch.nn as nn + + +class AlexNet(nn.Module): + def __init__(self, num_classes=1000): + super(AlexNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x diff --git a/examples/pipeline_parallelism/ds_config.json b/examples/pipeline_parallelism/ds_config.json new file mode 100644 index 0000000..0f97f59 --- /dev/null +++ b/examples/pipeline_parallelism/ds_config.json @@ -0,0 +1,19 @@ + { + "train_batch_size" : 256, + "train_micro_batch_size_per_gpu" : 8, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-8 + } + }, + + "steps_per_print" : 10, + "wall_clock_breakdown" : false + } diff --git a/examples/pipeline_parallelism/run.sh b/examples/pipeline_parallelism/run.sh new file mode 100755 index 0000000..9753282 --- /dev/null +++ b/examples/pipeline_parallelism/run.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +deepspeed train.py --deepspeed_config=ds_config.json -p 2 --steps=200 diff --git a/examples/pipeline_parallelism/train.py b/examples/pipeline_parallelism/train.py new file mode 100755 index 0000000..1a418b4 --- /dev/null +++ b/examples/pipeline_parallelism/train.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +import os +import argparse + +import torch +import torch.distributed as dist + +import torchvision +import torchvision.transforms as transforms +from torchvision.models import AlexNet +from torchvision.models import vgg19 + +import deepspeed +from deepspeed.pipe import PipelineModule +from deepspeed.utils import RepeatingLoader + + +def cifar_trainset(local_rank, dl_path='/tmp/cifar10-data'): + transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ]) + + # Ensure only one rank downloads. + # Note: if the download path is not on a shared filesytem, remove the semaphore + # and switch to args.local_rank + dist.barrier() + if local_rank != 0: + dist.barrier() + trainset = torchvision.datasets.CIFAR10(root=dl_path, + train=True, + download=True, + transform=transform) + if local_rank == 0: + dist.barrier() + return trainset + + +def get_args(): + parser = argparse.ArgumentParser(description='CIFAR') + parser.add_argument('--local_rank', + type=int, + default=-1, + help='local rank passed from distributed launcher') + parser.add_argument('-s', + '--steps', + type=int, + default=100, + help='quit after this many steps') + parser.add_argument('-p', + '--pipeline-parallel-size', + type=int, + default=2, + help='pipeline parallelism') + parser.add_argument('--backend', + type=str, + default='nccl', + help='distributed backend') + parser.add_argument('--seed', type=int, default=1138, help='PRNG seed') + parser = deepspeed.add_config_arguments(parser) + args = parser.parse_args() + return args + + +def train_base(args): + torch.manual_seed(args.seed) + + # VGG also works :-) + #net = vgg19(num_classes=10) + net = AlexNet(num_classes=10) + + trainset = cifar_trainset(args.local_rank) + + engine, _, dataloader, __ = deepspeed.initialize( + args=args, + model=net, + model_parameters=[p for p in net.parameters() if p.requires_grad], + training_data=trainset) + + dataloader = RepeatingLoader(dataloader) + data_iter = iter(dataloader) + + rank = dist.get_rank() + gas = engine.gradient_accumulation_steps() + + criterion = torch.nn.CrossEntropyLoss() + + total_steps = args.steps * engine.gradient_accumulation_steps() + step = 0 + for micro_step in range(total_steps): + batch = next(data_iter) + inputs = batch[0].to(engine.device) + labels = batch[1].to(engine.device) + + outputs = engine(inputs) + loss = criterion(outputs, labels) + engine.backward(loss) + engine.step() + + if micro_step % engine.gradient_accumulation_steps() == 0: + step += 1 + if rank == 0 and (step % 10 == 0): + print(f'step: {step:3d} / {args.steps:3d} loss: {loss}') + + + +def join_layers(vision_model): + layers = [ + *vision_model.features, + vision_model.avgpool, + lambda x: torch.flatten(x, 1), + *vision_model.classifier, + ] + return layers + + +def train_pipe(args, part='parameters'): + torch.manual_seed(args.seed) + deepspeed.runtime.utils.set_random_seed(args.seed) + + # + # Build the model + # + + # VGG also works :-) + #net = vgg19(num_classes=10) + net = AlexNet(num_classes=10) + net = PipelineModule(layers=join_layers(net), + loss_fn=torch.nn.CrossEntropyLoss(), + num_stages=args.pipeline_parallel_size, + partition_method=part, + activation_checkpoint_interval=0) + + trainset = cifar_trainset(args.local_rank) + + engine, _, _, _ = deepspeed.initialize( + args=args, + model=net, + model_parameters=[p for p in net.parameters() if p.requires_grad], + training_data=trainset) + + for step in range(args.steps): + loss = engine.train_batch() + + +if __name__ == '__main__': + args = get_args() + + deepspeed.init_distributed(dist_backend=args.backend) + args.local_rank = int(os.environ['LOCAL_RANK']) + torch.cuda.set_device(args.local_rank) + + if args.pipeline_parallel_size == 0: + train_base(args) + else: + train_pipe(args) diff --git a/intel_extension_for_deepspeed/__init__.py b/intel_extension_for_deepspeed/__init__.py new file mode 100644 index 0000000..cf96805 --- /dev/null +++ b/intel_extension_for_deepspeed/__init__.py @@ -0,0 +1 @@ +from .xpu_accelerator import XPU_Accelerator diff --git a/intel_extension_for_deepspeed/op_builder/__init__.py b/intel_extension_for_deepspeed/op_builder/__init__.py new file mode 100755 index 0000000..22afbfa --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/__init__.py @@ -0,0 +1,7 @@ +from .builder import OpBuilder +from .cpu_adam import CPUAdamBuilder +from .cpu_adagrad import CPUAdagradBuilder +from .fused_adam import FusedAdamBuilder +from .transformer import TransformerBuilder +from .quantizer import QuantizerBuilder +from .utils import UtilsBuilder diff --git a/intel_extension_for_deepspeed/op_builder/builder.py b/intel_extension_for_deepspeed/op_builder/builder.py new file mode 100644 index 0000000..7d39a0b --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/builder.py @@ -0,0 +1,61 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +import os +import shutil +from pathlib import Path +from deepspeed.ops.op_builder.builder import OpBuilder, TORCH_MAJOR, TORCH_MINOR + + +class SYCLOpBuilder(OpBuilder): + def builder(self): + try: + from intel_extension_for_pytorch.xpu.cpp_extension import DPCPPExtension + except ImportError: + from intel_extension_for_pytorch.xpu.utils import DPCPPExtension + + print("dpcpp sources = {}".format(self.sources())) + dpcpp_ext = DPCPPExtension( + name=self.absolute_name(), + sources=self.strip_empty_entries(self.sources()), + include_dirs=self.strip_empty_entries(self.include_paths()), + extra_compile_args={ + 'cxx': self.strip_empty_entries(self.cxx_args()), + }, + extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + return dpcpp_ext + + def version_dependent_macros(self): + # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 + version_ge_1_1 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): + version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_3 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): + version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_5 = [] + if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): + version_ge_1_5 = ['-DVERSION_GE_1_5'] + return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 + + def cxx_args(self): + return ['-O3', '-g', '-std=c++20', '-w', '-fPIC', '-DMKL_ILP64'] + + def extra_ldflags(self): + return ['-fPIC', '-Wl,-export-dynamic'] + + +def sycl_kernel_path(code_path): + import intel_extension_for_pytorch + abs_path = os.path.join(Path(__file__).parent.absolute(), code_path) + rel_path = os.path.join("third-party", code_path) + print("Copying SYCL kernel file from {} to {}".format(abs_path, rel_path)) + os.makedirs(os.path.dirname(rel_path), exist_ok=True) + shutil.copyfile(abs_path, rel_path) + return rel_path + + +def sycl_kernel_include(code_path): + import intel_extension_for_pytorch + abs_path = os.path.join(Path(__file__).parent.absolute(), code_path) + return abs_path diff --git a/intel_extension_for_deepspeed/op_builder/cpu_adagrad.py b/intel_extension_for_deepspeed/op_builder/cpu_adagrad.py new file mode 100644 index 0000000..ea379cd --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/cpu_adagrad.py @@ -0,0 +1,21 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +from .builder import SYCLOpBuilder + + +class CPUAdagradBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAGRAD" + NAME = "cpu_adagrad" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adagrad.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] diff --git a/intel_extension_for_deepspeed/op_builder/cpu_adam.py b/intel_extension_for_deepspeed/op_builder/cpu_adam.py new file mode 100644 index 0000000..a67fcba --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/cpu_adam.py @@ -0,0 +1,31 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +from .builder import SYCLOpBuilder, sycl_kernel_path, sycl_kernel_include + + +class CPUAdamBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_CPU_ADAM" + NAME = "cpu_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return [ + sycl_kernel_path('csrc/adam/sycl/cpu_adam.dp.cpp'), + sycl_kernel_path('csrc/adam/sycl/custom_sycl_kernel.dp.cpp'), + ] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + return [ + sycl_kernel_include('csrc/includes'), + sycl_kernel_include('csrc/adam'), 'csrc/includes' + ] diff --git a/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/cpu_adam.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/cpu_adam.dp.cpp new file mode 100644 index 0000000..940572d --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/cpu_adam.dp.cpp @@ -0,0 +1,707 @@ +#include "sycl/cpu_adam.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "sycl/custom_sycl_layers.hpp" + +static std::unordered_map> s_optimizers; + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + +// C++ interface + +void Adam_Optimizer::Step(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + sycl::half* dev_params, + bool half_precision) +{ + sycl::half* grads_cast_h; + sycl::half* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast(grads); + params_cast_h = reinterpret_cast(_params); + } + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { _streams[_buf_index]->wait(); } + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data grad_4; + grad_4.data = SIMD_LOAD(grads + i); + + AVX_Data momentum_4; + momentum_4.data = SIMD_LOAD(_exp_avg + i); + AVX_Data variance_4; + variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + + AVX_Data param_4; + param_4.data = SIMD_LOAD(_params + i); + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data); + } + momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data); + momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data); + + variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data); + grad_4.data = SIMD_MUL(grad_4.data, grad_4.data); + variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data); + + grad_4.data = SIMD_SQRT(variance_4.data); + grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data); + grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data); + if (_weight_decay > 0 && _adamw_mode) { + param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data); + } + param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); + + SIMD_STORE(_params + i, param_4.data); + + if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data); + + SIMD_STORE(_exp_avg + i, momentum_4.data); + SIMD_STORE(_exp_avg_sq + i, variance_4.data); + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + _buf_index = !_buf_index; + } + } + +#endif + + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { _streams[_buf_index]->wait(); } +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; + if (dev_params) _doubled_buffer[_buf_index][k - t] = param; + + if (half_precision) + params_cast_h[k] = (sycl::half)param; + else + _params[k] = param; + + // _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); + _buf_index = !_buf_index; + } + } + } +} + +void Adam_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + sycl::half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { _streams[_buf_index]->wait(); } +#pragma omp parallel for + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { + AVX_Data grad_4[4]; + grad_4[0].data = SIMD_LOAD(grads + i); + grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH); + grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1)); + grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3); + + AVX_Data momentum_4[4]; + momentum_4[0].data = SIMD_LOAD(_exp_avg + i); + momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH); + momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1)); + momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3); + + AVX_Data variance_4[4]; + variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i); + variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH); + variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1)); + variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3); + + AVX_Data param_4[4]; + param_4[0].data = SIMD_LOAD(_params + i); + param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH); + param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1)); + param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3); + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data); + grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data); + grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data); + grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data); + } + + momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data); + momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data); + momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data); + momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data); + momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data); + momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data); + momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data); + momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data); + + variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data); + variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data); + variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data); + variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data); + grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data); + grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data); + grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data); + grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data); + variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data); + variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data); + variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data); + variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data); + + grad_4[0].data = SIMD_SQRT(variance_4[0].data); + grad_4[1].data = SIMD_SQRT(variance_4[1].data); + grad_4[2].data = SIMD_SQRT(variance_4[2].data); + grad_4[3].data = SIMD_SQRT(variance_4[3].data); + + grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data); + grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data); + grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data); + grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data); + grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data); + grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data); + grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data); + grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data); + param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data); + param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data); + param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data); + } + + param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data); + param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data); + param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data); + param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data); + + SIMD_STORE(_params + i, param_4[0].data); + SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data); + SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data); + SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data); + + if (dev_params) { + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1), + param_4[2].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data); + } + + SIMD_STORE(_exp_avg + i, momentum_4[0].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data); + SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data); + + SIMD_STORE(_exp_avg_sq + i, variance_4[0].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data); + SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data); + } + + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + _buf_index = !_buf_index; + } + } +#endif + if (_param_size > rounded_size) + Step((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + sycl::half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { _streams[_buf_index]->wait(); } +#pragma omp parallel for + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { + AVX_Data grad_4[8]; + grad_4[0].data = SIMD_LOAD(grads + i); + grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH); + grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1)); + grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3); + grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2)); + grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5); + grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6); + grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7); + + AVX_Data momentum_4[8]; + momentum_4[0].data = SIMD_LOAD(_exp_avg + i); + momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH); + momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1)); + momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3); + momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2)); + momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5); + momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6); + momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7); + + AVX_Data variance_4[8]; + variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i); + variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH); + variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1)); + variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3); + variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2)); + variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5); + variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6); + variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7); + + AVX_Data param_4[8]; + param_4[0].data = SIMD_LOAD(_params + i); + param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH); + param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1)); + param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3); + param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2)); + param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5); + param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6); + param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7); + + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data); + grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data); + grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data); + grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data); + grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data); + grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data); + grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data); + grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data); + } + + momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data); + momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data); + momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data); + momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data); + momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data); + momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data); + momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data); + momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data); + momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data); + momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data); + momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data); + momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data); + momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data); + momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data); + momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data); + momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data); + + variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data); + variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data); + variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data); + variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data); + variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data); + variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data); + variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data); + variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data); + grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data); + grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data); + grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data); + grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data); + grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data); + grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data); + grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data); + grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data); + variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data); + variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data); + variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data); + variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data); + variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data); + variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data); + variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data); + variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data); + + grad_4[0].data = SIMD_SQRT(variance_4[0].data); + grad_4[1].data = SIMD_SQRT(variance_4[1].data); + grad_4[2].data = SIMD_SQRT(variance_4[2].data); + grad_4[3].data = SIMD_SQRT(variance_4[3].data); + grad_4[4].data = SIMD_SQRT(variance_4[4].data); + grad_4[5].data = SIMD_SQRT(variance_4[5].data); + grad_4[6].data = SIMD_SQRT(variance_4[6].data); + grad_4[7].data = SIMD_SQRT(variance_4[7].data); + + grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data); + grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data); + grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data); + grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data); + grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data); + grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data); + grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data); + grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data); + grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data); + grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data); + grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data); + grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data); + grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data); + grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data); + grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data); + grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data); + + if (_weight_decay > 0 && _adamw_mode) { + param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data); + param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data); + param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data); + param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data); + param_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, param_4[4].data); + param_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, param_4[5].data); + param_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, param_4[6].data); + param_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, param_4[7].data); + } + + param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data); + param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data); + param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data); + param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data); + param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data); + param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data); + param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data); + param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data); + + SIMD_STORE(_params + i, param_4[0].data); + SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data); + SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data); + SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data); + SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data); + SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data); + SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data); + SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data); + + if (dev_params) { + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1), + param_4[2].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2), + param_4[4].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data); + SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data); + } + + SIMD_STORE(_exp_avg + i, momentum_4[0].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data); + SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data); + SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data); + SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data); + + SIMD_STORE(_exp_avg_sq + i, variance_4[0].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data); + SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data); + SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data); + SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data); + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + _buf_index = !_buf_index; + } + } +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.size(0), + nullptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int ds_adam_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ + auto params_c = params.contiguous(); + auto gpu_params_c = gpu_params.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + auto grads_c = grads.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + sycl::half* gpu_params_ptr = (sycl::half*)gpu_params_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8( + params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr); + + opt->SynchronizeStreams(); + return 0; +} + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + m.def("adam_update_copy", + &ds_adam_step_plus_copy, + "DeepSpeed CPU Adam update and param copy (C++)"); + m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); + m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/custom_sycl_kernel.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/custom_sycl_kernel.dp.cpp new file mode 100644 index 0000000..7049c31 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/custom_sycl_kernel.dp.cpp @@ -0,0 +1,26 @@ +#include +#include "sycl/custom_sycl_layers.hpp" + +void param_update_kernel(const float* input, + sycl::half* output, + int size, + sycl::nd_item<3> item_ct1) +{ + int id = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + + if (id < size) { output[id] = (sycl::half)input[id]; } +} + +void launch_param_update(const float* input, sycl::half* output, int size, sycl::queue* stream) +{ + int threads = 1024; + + sycl::range<3> grid_dim(1, 1, (size - 1) / threads + 1); + sycl::range<3> block_dim(1, 1, threads); + + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { param_update_kernel(input, output, size, item_ct1); }); + }); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/fused_adam_frontend.cpp b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/fused_adam_frontend.cpp new file mode 100644 index 0000000..37bab4a --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/fused_adam_frontend.cpp @@ -0,0 +1,20 @@ +#include + +void multi_tensor_adam_sycl(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("multi_tensor_adam", + &multi_tensor_adam_sycl, + "Compute and apply gradient update to parameters for Adam optimizer"); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/multi_tensor_adam.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/multi_tensor_adam.dp.cpp new file mode 100644 index 0000000..f4aafe1 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/adam/sycl/multi_tensor_adam.dp.cpp @@ -0,0 +1,215 @@ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex +*/ + +#include +#include + +#include "multi_tensor_apply.dp.hpp" +#include "sycl/type_shim.hpp" + +#define BLOCK_SIZE 512 +#define ILP 4 + +typedef enum { + ADAM_MODE_0 = 0, // L2 regularization mode + ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW) +} adamMode_t; + +using MATH_T = float; + +template +void AdamFunctor(sycl::nd_item<1> item_ct1, + int chunk_size, + int* noop_gmem, + const int tensor_loc, + const int chunk_idx, + int n, + T* g, + T* p, + T* m, + T* v, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float epsilon, + const float lr, + const int mode, + const float decay) +{ + g += chunk_idx * chunk_size; + + p += chunk_idx * chunk_size; + + m += chunk_idx * chunk_size; + + v += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (int i_start = 0; i_start < n && i_start < chunk_size; + i_start += item_ct1.get_local_range(0) * ILP) { + MATH_T r_g[ILP]; + MATH_T r_p[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + item_ct1.get_local_id(0) + ii * item_ct1.get_local_range(0); + if (i < n && i < chunk_size) { + r_g[ii] = g[i]; + r_p[ii] = p[i]; + r_m[ii] = m[i]; + r_v[ii] = v[i]; + } else { + r_g[ii] = MATH_T(0); + r_p[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + } + } + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sycl::sqrt((float)next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sycl::sqrt((float)next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + item_ct1.get_local_id(0) + ii * item_ct1.get_local_range(0); + if (i < n && i < chunk_size) { + p[i] = r_p[ii]; + m[i] = r_m[ii]; + v[i] = r_v[ii]; + } + } + } +} + +void test_queue_with_accessor(void) +{ + printf("Test queue with accessor\n"); + auto type_ = c10::DeviceType::XPU; + c10::impl::VirtualGuardImpl impl(type_); + auto device_ = c10::Device(type_); + c10::Stream dpcpp_stream = impl.getStream(device_); + sycl::queue* stream = &(xpu::get_queue_from_stream(dpcpp_stream)); + sycl::default_selector d_selector; + static auto exception_handler = [](sycl::exception_list e_list) { + for (std::exception_ptr const& e : e_list) { + try { + std::rethrow_exception(e); + } catch (std::exception const& e) { + std::cout << "Failure" << std::endl; + std::terminate(); + } + } + }; + sycl::queue dq(d_selector, + exception_handler, + {sycl::property::queue::in_order(), sycl::property::queue::enable_profiling()}); + struct { + unsigned char block_to_tensor[320]; + int block_to_chunk[320]; + void* addresses[4][36]; + int sizes[36]; + } tll; + sycl::buffer block_to_tensor_buf(&(tll.block_to_tensor[0]), {320}); + sycl::buffer block_to_chunk_buf(&(tll.block_to_chunk[0]), {320}); + sycl::buffer addresses_buf(&(tll.addresses[0][0]), {4, 36}); + sycl::buffer sizes_buf(&(tll.sizes[0]), {36}); + printf("submit dq without accessor "); + dq.submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(320 * 512, 512), [=](sycl::nd_item<1> item_ct1) {}); + }); + dq.wait(); + printf("done\n"); + printf("submit dq with accessor "); + dq.submit([&](sycl::handler& cgh) { + sycl::accessor tl_block_to_tensor(block_to_tensor_buf, cgh, sycl::read_only); + sycl::accessor tl_block_to_chunk(block_to_chunk_buf, cgh, sycl::read_only); + sycl::accessor tl_addresses(addresses_buf, cgh, sycl::read_only); + sycl::accessor tl_sizes(sizes_buf, cgh, sycl::read_only); + cgh.parallel_for(sycl::nd_range<1>(320 * 512, 512), [=](sycl::nd_item<1> item_ct1) {}); + }); + dq.wait(); + printf("done\n"); + printf("submit xpu::stream without accessor "); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<1>(320 * 512, 512), [=](sycl::nd_item<1> item_ct1) {}); + }); + stream->wait(); + printf("done\n"); + printf("submit xpu::stream with accessor "); + stream->submit([&](sycl::handler& cgh) { + sycl::accessor tl_block_to_tensor(block_to_tensor_buf, cgh, sycl::read_only); + sycl::accessor tl_block_to_chunk(block_to_chunk_buf, cgh, sycl::read_only); + sycl::accessor tl_addresses(addresses_buf, cgh, sycl::read_only); + sycl::accessor tl_sizes(sizes_buf, cgh, sycl::read_only); + cgh.parallel_for(sycl::nd_range<1>(320 * 512, 512), [=](sycl::nd_item<1> item_ct1) {}); + }); + stream->wait(); + printf("done\n"); +} + +void multi_tensor_test(void) +{ + printf("inside multi_tensor_test\n"); + test_queue_with_accessor(); +} + +void multi_tensor_adam_sycl(int chunk_size, + at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, + const float beta1, + const float beta2, + const float epsilon, + const int step, + const int mode, + const int bias_correction, + const float weight_decay) +{ + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + // Assume single type across p,g,m1,m2 + DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), + 0, + "adam", + multi_tensor_apply<4, scalar_t_0>(BLOCK_SIZE, + chunk_size, + noop_flag, + tensor_lists, + beta1, + beta2, + bias_correction1, + bias_correction2, + epsilon, + lr, + mode, + weight_decay)) +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/multi_tensor_apply.dp.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/multi_tensor_apply.dp.hpp new file mode 100644 index 0000000..30de518 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/multi_tensor_apply.dp.hpp @@ -0,0 +1,174 @@ +/* Copyright 2020 The Microsoft DeepSpeed Team + Copyright NVIDIA/apex + This file is adapted from fused adam in NVIDIA/apex, commit a109f85 +*/ +#pragma once + +#include +#include +#include +#include "compat.h" +#include "sycl/context.hpp" + +#include +#include + +// #include + +// This header is the one-stop shop for all your multi-tensor apply needs. + +constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; +constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; + +template +struct TensorListMetadata { + void* addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a + // full int. + int start_tensor_this_launch; +}; + +template +SYCL_EXTERNAL void AdamFunctor(sycl::nd_item<1> item_ct1, + int chunk_size, + int* noop_gmem, + const int tensor_loc, + const int chunk_idx, + int n, + T* g, + T* p, + T* m, + T* v, + const float beta1, + const float beta2, + const float beta1_correction, + const float beta2_correction, + const float epsilon, + const float lr, + const int mode, + const float decay); + +template +void multi_tensor_apply(int block_size, + int chunk_size, + const at::Tensor& noop_flag, + const std::vector>& tensor_lists, + const float beta1, + const float beta2, + const float beta_correction1, + const float beta_correction2, + const float epsilon, + const float lr, + const int mode, + const float decay) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kXPU, "expected input to be on XPU"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) { + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); +#ifdef VERSION_GE_1_5 + contiguous_memory = (contiguous_memory || + tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); +#endif + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, + "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + sycl::queue* stream = SyclContext::Instance().GetCurrentStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) { + int* data_ptr = noop_flag.DATA_PTR(); + sycl::buffer block_to_tensor_buf(&(tl.block_to_tensor[0]), {320}); + sycl::buffer block_to_chunk_buf(&(tl.block_to_chunk[0]), {320}); + sycl::buffer addresses_buf(&(tl.addresses[0][0]), {4, 36}); + sycl::buffer sizes_buf(&(tl.sizes[0]), {36}); + sycl::buffer data_buf(data_ptr, noop_flag.numel()); + stream->submit([&](sycl::handler& cgh) { + sycl::accessor tl_block_to_tensor(block_to_tensor_buf, cgh, sycl::read_only); + sycl::accessor tl_block_to_chunk(block_to_chunk_buf, cgh, sycl::read_only); + sycl::accessor tl_addresses(addresses_buf, cgh, sycl::read_only); + sycl::accessor tl_sizes(sizes_buf, cgh, sycl::read_only); + sycl::accessor data_acc(data_buf, cgh, sycl::read_only); + cgh.parallel_for(sycl::nd_range<1>(loc_block_info * block_size, block_size), + [=](sycl::nd_item<1> item_ct1) { + int tensor_loc = tl_block_to_tensor[item_ct1.get_group(0)]; + int chunk_idx = tl_block_to_chunk[item_ct1.get_group(0)]; + int n = tl_sizes[tensor_loc]; + T* g = (T*)tl_addresses[0][tensor_loc]; + T* p = (T*)tl_addresses[1][tensor_loc]; + T* m = (T*)tl_addresses[2][tensor_loc]; + T* v = (T*)tl_addresses[3][tensor_loc]; + + AdamFunctor(item_ct1, + chunk_size, + data_acc.get_pointer(), + tensor_loc, + chunk_idx, + n, + g, + p, + m, + v, + beta1, + beta2, + beta_correction1, + beta_correction2, + epsilon, + lr, + mode, + decay); + }); + }); + + // Reset. The control flow possibilities here make my brain hurt. + loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) { + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } else { + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } + } + } +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/Timer.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/Timer.hpp new file mode 100644 index 0000000..d5a5f84 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/Timer.hpp @@ -0,0 +1,41 @@ + +#ifndef __TIMER_H__ +#define __TIMER_H__ + +#include +#include + +class GPUTimer { + sycl::event start, stop; + std::chrono::time_point start_ct1; + std::chrono::time_point stop_ct1; + +public: + GPUTimer() {} + ~GPUTimer() {} + + inline void Record() { start_ct1 = std::chrono::steady_clock::now(); } + inline void Elapsed(float& time_elapsed) + { + stop_ct1 = std::chrono::steady_clock::now(); + stop.wait_and_throw(); + time_elapsed = std::chrono::duration(stop_ct1 - start_ct1).count(); + } +}; + +class CPUTimer { + std::chrono::high_resolution_clock::time_point start; + +public: + CPUTimer() : start(std::chrono::high_resolution_clock::now()) {} + inline void Reset() { start = std::chrono::high_resolution_clock::now(); } + inline float Elapsed() + { + auto temp = start; + start = std::chrono::high_resolution_clock::now(); + return (float)(std::chrono::duration_cast(start - temp).count() / + 1e3); + } +}; + +#endif diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/common.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/common.hpp new file mode 100644 index 0000000..953bd3c --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/common.hpp @@ -0,0 +1,30 @@ +#pragma once +#include + +#define CHECK_XPU(x) AT_ASSERTM(x.is_xpu(), #x " must be a XPU tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_XPU(x); \ + CHECK_CONTIGUOUS(x) + +template +inline void print_nan(sycl::queue* stream, int bsz, const T* buf, char* name) +{ + T temp_tensor[10000]; + bool has_nan = false; + stream->wait(); + stream->memcpy(temp_tensor, buf, bsz * sizeof(T)); + stream->wait(); + for (int i = 0; i < bsz; i++) { + if (isnan(float(temp_tensor[i]))) { has_nan = true; } + } + printf("%s[%d](%p)%s --> ", name, bsz, buf, has_nan ? "has_nan" : ""); + for (int i = 0; i < bsz; i++) { + if (isnan(float(temp_tensor[i]))) { + printf("%d:nan ", i); + } else { + printf("%d:%f, ", i, float(temp_tensor[i])); + } + } + printf("\n"); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/context.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/context.hpp new file mode 100644 index 0000000..2934610 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/context.hpp @@ -0,0 +1,146 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using bf16 = sycl::ext::oneapi::experimental::bfloat16; + +#define WARP_SIZE 32 +#define ONEMKL_OP_T oneapi::mkl::transpose::trans +#define ONEMKL_OP_N oneapi::mkl::transpose::nontrans + +#define DPCPP_1D_KERNEL_LOOP(i, n) \ + for (size_t(i) = item_ct1.get_global_id(2); (i) < (n); (i) += item_ct1.get_global_range(2)) + +#define DPCPP_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = item_ct1.get_global_id(2); (i) < (n); (i) += item_ct1.get_global_range(2)) \ + for (size_t j = item_ct1.get_global_id(1); (j) < (m); (j) += item_ct1.get_global_range(1)) + +#define DS_CUDA_NUM_THREADS 512 +#define DS_MAXIMUM_NUM_BLOCKS 262144 + +inline int DS_GET_BLOCKS(const int N) +{ + return (std::max)( + (std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since CUDA does not allow empty block + 1); +} + +class SyclContext { +public: + SyclContext() + try : _workspace(nullptr), _seed(42), _curr_offset(0) { + auto type_ = c10::DeviceType::XPU; + c10::impl::VirtualGuardImpl impl(type_); + auto device_ = c10::Device(type_); + c10::Stream dpcpp_stream = impl.getStream(device_); + _gen = new oneapi::mkl::rng::philox4x32x10(xpu::get_queue_from_stream(dpcpp_stream), 123); + if ((_onemklQ = &xpu::get_queue_from_stream(dpcpp_stream), 0) != 0) { + auto message = std::string("Fail to create onemkl queue."); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ + << std::endl; + std::exit(1); + } + + virtual ~SyclContext() + { + _onemklQ = nullptr; + free(_gen); + + auto type_ = c10::DeviceType::XPU; + c10::impl::VirtualGuardImpl impl(type_); + auto device_ = c10::Device(type_); + c10::Stream dpcpp_stream = impl.getStream(device_); + sycl::free(_workspace, xpu::get_queue_from_stream(dpcpp_stream)); + } + + static SyclContext& Instance() + { + static SyclContext _ctx; + return _ctx; + } + + void SetWorkSpace(void* workspace) + { + if (!workspace) { throw std::runtime_error("Workspace is null."); } + _workspace = workspace; + } + + void* GetWorkSpace() { return _workspace; } + + sycl::queue* GetCurrentStream() + { + // get current pytorch stream. + // return &xpu::dpcpp::getCurrentDPCPPStream().dpcpp_queue(); + + auto type_ = c10::DeviceType::XPU; + c10::impl::VirtualGuardImpl impl(type_); + auto device_ = c10::Device(type_); + c10::Stream dpcpp_stream = impl.getStream(device_); + return &xpu::get_queue_from_stream(dpcpp_stream); + } + + sycl::queue* GetNewStream() + { + auto type_ = c10::DeviceType::XPU; + c10::impl::VirtualGuardImpl impl(type_); + auto device_ = c10::Device(type_); + c10::Stream dpcpp_stream = impl.getStream(device_); + c10::Stream stream = impl.getStreamFromGlobalPool(device_, /*isHighPriority=*/false); + + return &xpu::get_queue_from_stream(dpcpp_stream); + } + + sycl::queue* GetOneMKLQ() { return _onemklQ; } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + // set _GPT_DEBUG_ and fix seed to avoid randomness +#ifdef _GPT_DEBUG_ + return std::pair(_seed, 0); +#else + return std::pair(_seed, offset); +#endif + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + + void TestGemmFP16(bool test_gemm, int batch_size, int seq_len, int head_num, int size_per_head) + { + // avoid rerun. + if (_gemm_algos.size() > 0) return; + + // Use default algo. + _gemm_algos.push_back(std::array({99, 99, 99})); + _gemm_algos.push_back(std::array({99, 99, 99})); + _gemm_algos.push_back(std::array({99, 99, 99})); + _gemm_algos.push_back(std::array({99, 99, 99})); + _gemm_algos.push_back(std::array({99, 99, 99})); + } + + const std::vector>& GetGemmAlgos() const { return _gemm_algos; } + +private: + oneapi::mkl::rng::philox4x32x10* _gen; + sycl::queue* _onemklQ; + void* _workspace; + uint64_t _seed; + uint64_t _curr_offset; + std::vector> _gemm_algos; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/cpu_adam.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/cpu_adam.hpp new file mode 100644 index 0000000..3af3109 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/cpu_adam.hpp @@ -0,0 +1,164 @@ +#pragma once + +#if (__x86_64__ || __i386__) +#include +#include +#include +#endif + +#include +#include +#include +#include "context.hpp" + +#include +#include + +#include + +#define TILE (128 * 1024 * 1024) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_WIDTH 16 +#else +#if defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_WIDTH 8 +#endif +#endif + +class Adam_Optimizer { +public: + Adam_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _buf_index(false), + _adamw_mode(adamw_mode) + { + _streams[0] = ::SyclContext::Instance().GetCurrentStream(); + _streams[1] = ::SyclContext::Instance().GetNewStream(); + sycl::queue& q_ct1 = *_streams[0]; + + *_doubled_buffer = sycl::malloc_host(TILE, q_ct1); + *(_doubled_buffer + 1) = sycl::malloc_host(TILE, q_ct1); + } + ~Adam_Optimizer() + { + sycl::queue& q_ct1 = *_streams[0]; + sycl::free(_doubled_buffer[0], q_ct1); + sycl::free(_doubled_buffer[1], q_ct1); + } + void Step(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t param_size, + sycl::half* dev_param = nullptr, + bool half_precision = false); + void Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sa, + size_t param_size, + sycl::half* dev_param = nullptr, + bool half_precision = false); + void Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + sycl::half* dev_params = nullptr, + bool half_precision = false); + inline void SynchronizeStreams() + { + for (int i = 0; i < 2; i++) _streams[i]->wait(); + } + inline void IncrementStep(size_t step, float beta1, float beta2) + { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +private: +#if defined(__AVX512__) or defined(__AVX256__) + union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#else + __m256 data; +#endif + // float data_f[16]; + }; +#endif + + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + float* _doubled_buffer[2]; + bool _buf_index; + bool _adamw_mode; + + sycl::queue* _streams[2]; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/custom_sycl_layers.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/custom_sycl_layers.hpp new file mode 100644 index 0000000..be6dfbb --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/custom_sycl_layers.hpp @@ -0,0 +1,283 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "context.hpp" +#include "onednn_wrappers.hpp" +#include "onemkl_wrappers.hpp" + +#define MAX_THREADS 1024 +#define THREADS 256 + +#define MAX_THREAD_STRIDE 32 +#define TILE_DIM 16 + +// Maximum sequence-length support based on the number of threads (2048) allowed +// in each block and this MAX is 8K For higher sequence length we need to use +// higher Max, like for 64K : 32 +#define MAX_THREAD_ITERATIONS 8 // Maximum 8K +#define MAX_WARP_NUM 32 + +#define MAX_REGISTERS 256 + +#define MAX_REG 256 + +template +void launch_qunatize_kernel(T* vals, + int total_count, + int group_num, + int num_bits, + sycl::queue* stream); +template +void launch_sr_qunatize_kernel(T* vals, + int total_count, + int group_num, + int num_bits, + sycl::queue* stream); +template +void launch_qunatize_kernel_asym(T* vals, + int total_count, + int group_num, + int num_bits, + sycl::queue* stream); +template +void launch_sr_qunatize_kernel_asym(T* vals, + int total_count, + int group_num, + int num_bits, + sycl::queue* stream); +// Fused bias add with gelu activation +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + sycl::queue* stream); + +template +void launch_gelu(const T* input, + T* output, + int intermediate_size, + int batch_size, + sycl::queue* stream); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + sycl::queue* stream); + +// Custom fused bias add with layer normalization +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + sycl::queue* stream, + bool preLayerNorm, + bool training, + T* vars, + T* means); + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + sycl::queue* stream, + bool preLayerNorm, + bool training, + T* vars); + +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + sycl::queue* stream[2]); +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + sycl::queue* stream[2], + bool invertible = false, + const T* betta = nullptr); + +template +void launch_layerNorm_backward(const T* out_grad, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + sycl::queue* stream[2]); + +template +void launch_layerNorm_backward(const T* out_grad, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + sycl::queue* stream[2], + bool invertible = false, + const T* betta = nullptr); + +template +void launch_layerNorm_backward_nreversible(const T* out_grad, + const T* vals, + const T* out_grad_trans, + const T* vals_trans, + const T* means, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch_size, + int hidden_dim, + sycl::queue* stream[2]); + +template +void Transpose(const T* inp_mat, T* out_mat, int rows, int cols, sycl::queue* stream); + +template +void launch_attn_softmax_backward(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + sycl::queue* stream); + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + sycl::queue* stream); + +// Custom softmax with scaling and attention mask addition +template +void launch_attn_softmax(T* vals, + const T* attn_mask, + int batch_size, + int heads, + int sequence_length, + sycl::queue* stream); + +template +void launch_transform_0213(T* output, + const T* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream); + +// Custom bias add +template +void launch_bias_add_transform_0213(T* outputs, + const T* vals, + const T* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream, + int trans_count); + +// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3] +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + sycl::queue* stream, + int trans_count); + +template +void launch_dropout(T* vals, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + sycl::queue* stream); + +template +void launch_dropout(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + sycl::queue* stream, + bool bwd = false); + +template +void launch_dropout(T* out, + const T* vals, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + sycl::queue* stream); + +template +void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, sycl::queue* stream); + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + sycl::queue* stream); + +template +void launch_fuse_transpose_bias_kernel(const T* inp, + T* out, + int rows, + int cols, + sycl::queue* stream); + +void launch_param_update(const float* input, sycl::half* output, int size, sycl::queue* stream); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/dropout.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/dropout.hpp new file mode 100644 index 0000000..d534938 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/dropout.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include +#include +#include "custom_sycl_layers.hpp" + +template +class Dropout { +public: + struct Config { + float ratio; + uint32_t dim; + bool training; + + Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {} + + float RATIO() const { return training ? ratio : 0.0; } + inline void SetDim(uint32_t d) { dim = d; } + }; + + Dropout(const Config& config) : _config(config), _mask(nullptr) {} + + virtual ~Dropout() {} + + void Forward(int bsz, T* out, const T* vals, sycl::queue* stream, bool bwd = false) + { + launch_dropout( + out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd); + } + + void ForwardWithBias(int bsz, T* vals, const T* bias, sycl::queue* stream) + { + launch_dropout(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); + } + + void ForwardWithBias(int bsz, + T* out, + const T* vals, + const T* residual, + const T* bias, + sycl::queue* stream) + { + launch_dropout( + out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); + } + + void Backward(int bsz, T* d_vals, sycl::queue* stream) + { + launch_dropout_grad(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); + } + + void Backward(int bsz, T* d_vals_out, const T* d_vals, sycl::queue* stream) + { + launch_dropout_grad( + d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + void SetMask(uint8_t* mask) + { + if (!mask) { throw std::runtime_error("Dropout mask is null."); } + + _mask = mask; + } + + Config GetConfig() const { return _config; } + + inline void SetDimension(uint32_t dim) { _config.SetDim(dim); } + +private: + uint8_t* _mask; + Config _config; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/ds_transformer_sycl.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/ds_transformer_sycl.hpp new file mode 100644 index 0000000..4b6bec1 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/ds_transformer_sycl.hpp @@ -0,0 +1,184 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "dropout.hpp" +#include "feed_forward.hpp" +#include "gelu.hpp" +#include "general_kernels.hpp" +#include "normalize_layer.hpp" +#include "softmax.hpp" +#include "strided_batch_gemm.hpp" + +struct BertGemmAlgos { + int m_gemm_qkv_algo; + int m_gemm_inter_algo; + int m_gemm_output_algo; + int m_gemm_batch1_algo; + int m_gemm_batch2_algo; + + BertGemmAlgos() + : m_gemm_qkv_algo(-1), + m_gemm_inter_algo(-1), + m_gemm_output_algo(-1), + m_gemm_batch1_algo(-1), + m_gemm_batch2_algo(-1) + { + } +}; + +template +class BertTransformerLayer { +public: + BertTransformerLayer(int layer_id, + int batch_size, + int hidden_size, + int num_heads, + int intermediate_size, + int seq_length, + float attn_dropout_ratio, + float hidden_output_dropout_ratio, + float layer_norm_eps, + bool pre_or_postLayerNorm, + const std::vector>& gemm_algos, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint, + bool stochastic_mode); + + virtual ~BertTransformerLayer(); + + void Forward(int bsz, + const T* input_ptr, + const T* input_mask_ptr, + const T* attn_qkvw_ptr, + const T* attn_qkvb_ptr, + const T* attn_ow_ptr, + const T* attn_ob_ptr, + const T* attn_nw_ptr, + const T* attn_nb_ptr, + const T* inter_w_ptr, + const T* inter_b_ptr, + const T* output_w_ptr, + const T* output_b_ptr, + const T* norm_w_ptr, + const T* norm_b_ptr, + T* out_ptr, + T* inp_norm_ptr, + T* q_tf_ptr, + T* k_tf_ptr, + T* v_tf_ptr, + T* softmax_output_ptr, + T* ctx_bufB_ptr, + T* attn_o_inp_ptr, + T* add_res_ptr, + T* ff1_inp_ptr, + T* gelu_inp_ptr, + T* ff2_inp_ptr); + + void Backward(int bsz, + const T* grad_output_ptr, + const T* input_ptr, + const T* output_ptr, + const T* inp_norm_ptr, + const T* q_tf_ptr, + const T* k_tf_ptr, + const T* v_tf_ptr, + const T* softmax_output_ptr, + const T* ctx_bufB_ptr, + const T* attn_o_inp_ptr, + const T* add_res_ptr, + const T* ff1_inp_ptr, + const T* gelu_inp_ptr, + const T* ff2_inp_ptr, + const T* input_mask_ptr, + const T* attn_qkvw_ptr, + const T* attn_ow_ptr, + const T* attn_nw_ptr, + const T* attn_nb_ptr, + const T* inter_w_ptr, + const T* inter_b_ptr, + const T* output_w_ptr, + const T* norm_w_ptr, + const T* norm_b_ptr, + + T* grad_input_ptr, + T* grad_attn_qkvw_ptr, + T* grad_attn_qkvb_ptr, + T* grad_attn_ow_ptr, + T* grad_attn_ob_ptr, + T* grad_attn_nw_ptr, + T* grad_attn_nb_ptr, + T* grad_inter_w_ptr, + T* grad_inter_b_ptr, + T* grad_output_w_ptr, + T* grad_output_b_ptr, + T* grad_norm_w_ptr, + T* grad_norm_b_ptr); + + void SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, + uint8_t* attn_output_dropout_mask_ptr, + uint8_t* layer_output_dropout_mask_ptr, + T* layer_norm_var, + T* layer_norm_mean, + T* attn_layer_norm_var, + T* attn_layer_norm_mean); + + inline int GetBatchSize() const { return _batch_size; } + inline int GetNumHeads() const { return _heads; } + inline int GetSeqLength() const { return _seq_length; } + inline int GetIntermediateSize() const { return _intermediate_size; } + + void SetSeqLength(int seq_len); + inline int GetHiddenSize() const { return _hidden_size; } + void SetTrainingMode(bool training); + inline bool IsTrainingMode() const { return _training; } + inline bool GeluCheckpoint() const { return _gelu_checkpoint; } + +private: + void Initialize(); + size_t getWorkspaceSize(int maxBatchSize) const; + + // Params + int _layer_id; + int _batch_size; + int _hidden_size; + int _heads; + int _size_per_head; + int _intermediate_size; + int _seq_length; + + bool _pre_or_postLayerNorm; + + sycl::queue* _onemklQ; + sycl::queue* _stream; + + // layers + FeedForward _qkv_linear; + FeedForward _attn_out_linear; + Normalize_Layer _attn_layer_norm; + Normalize_Layer _layer_norm; + Normalize_Layer* _last_normalize; + FeedForward _ff1, _ff2; + Softmax _softmax; + Gelu _gelu; + Dropout _attn_prob_dropout; + Dropout _attn_output_dropout; + Dropout _layer_output_dropout; + StridedBatchGemm _attn_scores; + StridedBatchGemm _attn_context; + + bool _training; + + // Memory saving flags + bool _attn_dropout_checkpoint; + bool _normalize_invertible; + bool _gelu_checkpoint; + + // High Performace flags + bool _stochastic_mode; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/feed_forward.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/feed_forward.hpp new file mode 100644 index 0000000..4145fdd --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/feed_forward.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include "custom_sycl_layers.hpp" + +template +class FeedForward { +public: + struct Config { + int batchSize, outputSize; + int inputSize; + Config(int batch, int outputs, int inputs) + : batchSize(batch), outputSize(outputs), inputSize(inputs) + { + } + }; + + FeedForward(Config config) : config_(config) {} + + ~FeedForward() {} + + void Forward(int bsz, const T* input_ptr, const T* weights, T* out, sycl::queue* _Q) + { + if constexpr (std::is_same_v) { + float alpha = 1.0f; + float beta = 0.0f; + onednn_matmul_ex(_Q, + false, + true, + bsz, + config_.outputSize, + config_.inputSize, + alpha, + beta, + input_ptr, + weights, + out); + } else { + T alpha = T(1.); + T beta = T(0.); + onemkl_gemm_ex(_Q, + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + config_.outputSize, + bsz, + config_.inputSize, + alpha, + beta, + weights, + input_ptr, + out); + } + } + void Backward(int bsz, + const T* out_grad, + const T* input_ptr, + const T* weights, + T* weights_grad, + T* bias_grad, + sycl::queue* _Q, + sycl::queue* stream, + T* inp_grad_out = nullptr, + T* out_grad_trans_out = nullptr) + { + if constexpr (std::is_same_v) { + float alpha = 1.0f; + float beta = 0.0f; + onednn_matmul_ex(stream, + true, + false, + config_.outputSize, + config_.inputSize, + bsz, + alpha, + beta, + out_grad, + input_ptr, + weights_grad); + onednn_matmul_ex(stream, + false, + false, + bsz, + config_.inputSize, + config_.outputSize, + alpha, + beta, + out_grad, + weights, + inp_grad_out); + launch_fuse_transpose_bias_kernel( + out_grad, bias_grad, bsz, config_.outputSize, stream); + } else { + T alpha = (T)1.0; + T beta = (T)0.0; + onemkl_gemm_ex(_Q, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::trans, + config_.inputSize, + config_.outputSize, + bsz, + alpha, + beta, + input_ptr, + out_grad, + weights_grad); + onemkl_gemm_ex(_Q, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + config_.inputSize, + bsz, + config_.outputSize, + alpha, + beta, + weights, + out_grad, + inp_grad_out); + launch_fuse_transpose_bias_kernel( + out_grad, bias_grad, bsz, config_.outputSize, stream); + } + } + +private: + Config config_; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gelu.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gelu.hpp new file mode 100644 index 0000000..84aa8d9 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gelu.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include "custom_sycl_layers.hpp" + +template +class Gelu { +public: + struct Config { + uint32_t intermediate_size; + Config(uint32_t inter_size) : intermediate_size(inter_size) {} + }; + + Gelu(const Config& config) : _config(config) {} + + virtual ~Gelu() {} + + void ForwardWithBiasAdd(int bsz, + const T* input_buf, + const T* bias, + T* output, + sycl::queue* stream) + { + launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream); + } + + void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, sycl::queue* stream) + { + launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); + } + +private: + Config _config; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gemm_test.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gemm_test.hpp new file mode 100644 index 0000000..0b4cac6 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/gemm_test.hpp @@ -0,0 +1,297 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "StopWatch.h" +#include "onemkl_wrappers.hpp" + +template +class GemmTest { +public: + GemmTest(int m, + int n, + int k, + oneapi::mkl::transpose ta, + oneapi::mkl::transpose tb, + sycl::queue* h) + : M(m), N(n), K(k), transa(ta), transb(tb), handle(h) + { + dpct::device_ext& dev_ct1 = dpct::get_current_device(); + sycl::queue& q_ct1 = dev_ct1.default_queue(); + A = (T*)sycl::malloc_device(sizeof(T) * M * K, q_ct1); + B = (T*)sycl::malloc_device(sizeof(T) * K * N, q_ct1); + C = (T*)sycl::malloc_device(sizeof(T) * M * N, q_ct1); + } + + ~GemmTest() + { + dpct::device_ext& dev_ct1 = dpct::get_current_device(); + sycl::queue& q_ct1 = dev_ct1.default_queue(); + sycl::free(A, q_ct1); + sycl::free(B, q_ct1); + sycl::free(C, q_ct1); + } + + std::array TestAlgo(int loops) + { + float alpha = (T)1.0f; + float beta = (T)0.0f; + + int algo_fw = Run(loops, [=](int algo) { + onemkl_gemm_ex(handle, + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + N, + M, + K, + &alpha, + &beta, + B, + A, + C, + static_cast(algo)); + }); + + int algo_bw1 = Run(loops, [=](int algo) { + onemkl_gemm_ex(handle, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::trans, + K, + N, + M, + &alpha, + &beta, + A, + C, + B, + static_cast(algo)); + }); + + int algo_bw2 = Run(loops, [=](int algo) { + onemkl_gemm_ex(handle, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + K, + M, + N, + &alpha, + &beta, + B, + C, + A, + static_cast(algo)); + }); + + return std::array({algo_fw, algo_bw1, algo_bw2}); + } + + template + int Run(int loops, Func f) + { + float fast_latency = (std::numeric_limits::max)(); + int fast_algo = 0; + + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + algo++) { + int warm_up = 5; + for (int i = 0; i < warm_up; ++i) f(algo); + + cudaDeviceSynchronize(); + Stopwatch timer; + timer.Restart(); + + for (int i = 0; i < loops; ++i) f(algo); + + cudaDeviceSynchronize(); + timer.Stop(); + + float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; + + printf("algo-%d: %.3fms\n", algo, avg_latency); + + if (avg_latency < fast_latency) { + fast_latency = avg_latency; + fast_algo = algo; + } + } + + printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); + + return fast_algo; + } + +private: + int M, N, K; + sycl::queue* handle; + oneapi::mkl::transpose transa, transb; + T *A, *B, *C; +}; + +template +class StridedGemmTest { +public: + StridedGemmTest(int b, + int m, + int n, + int k, + oneapi::mkl::transpose ta, + oneapi::mkl::transpose tb, + sycl::queue* h) + : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h) + { + dpct::device_ext& dev_ct1 = dpct::get_current_device(); + sycl::queue& q_ct1 = dev_ct1.default_queue(); + A = (T*)sycl::malloc_device(sizeof(T) * M * K * bsz, q_ct1); + B = (T*)sycl::malloc_device(sizeof(T) * K * N * bsz, q_ct1); + C = (T*)sycl::malloc_device(sizeof(T) * M * N * bsz, q_ct1); + } + + ~StridedGemmTest() + { + dpct::device_ext& dev_ct1 = dpct::get_current_device(); + sycl::queue& q_ct1 = dev_ct1.default_queue(); + sycl::free(A, q_ct1); + sycl::free(B, q_ct1); + sycl::free(C, q_ct1); + } + + std::array TestAlgo(int loops) + { + float alpha = (T)1.0f; + float beta = (T)0.0f; + + int algo_fw = Run(loops, [=](int algo) { + int stride_a = M * K; + int stride_b = N * K; + int stride_c = M * N; + + cublas_strided_batched_gemm(handle, + M, + N, + K, + &alpha, + &beta, + A, + B, + C, + transa, + transb, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + int algo_bw1 = Run(loops, [=](int algo) { + int mb = (transa == oneapi::mkl::transpose::trans ? K : M); + int kb = (transa == oneapi::mkl::transpose::trans ? M : K); + + int stride_a = mb * N; + int stride_b = N * kb; + int stride_c = M * K; + + // B need to transpose. + cublasOperation_t op_b = + (transb == oneapi::mkl::transpose::trans ? oneapi::mkl::transpose::nontrans + : oneapi::mkl::transpose::trans); + + // Calculate d_A. + cublas_strided_batched_gemm(handle, + mb, + kb, + N, + &alpha, + &beta, + (transa == oneapi::mkl::transpose::trans ? B : C), + (transa == oneapi::mkl::transpose::trans ? C : B), + A, + oneapi::mkl::transpose::nontrans, + op_b, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + int algo_bw2 = Run(loops, [=](int algo) { + // A need to transpose. + cublasOperation_t op_a = + (transa == oneapi::mkl::transpose::trans ? oneapi::mkl::transpose::nontrans + : oneapi::mkl::transpose::trans); + + int stride_a = M * K; + int stride_b = M * N; + int stride_c = N * K; + + // Calculate d_B. + cublas_strided_batched_gemm(handle, + K, + N, + M, + &alpha, + &beta, + A, + C, + B, + op_a, + oneapi::mkl::transpose::nontrans, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + return std::array({algo_fw, algo_bw1, algo_bw2}); + } + + template + int Run(int loops, Func f) + { + float fast_latency = (std::numeric_limits::max)(); + int fast_algo = 0; + + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + algo++) { + int warm_up = 5; + for (int i = 0; i < warm_up; ++i) f(algo); + + cudaDeviceSynchronize(); + Stopwatch timer; + timer.Restart(); + + for (int i = 0; i < loops; ++i) f(algo); + + cudaDeviceSynchronize(); + timer.Stop(); + + float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; + + printf("algo-%d: %.3fms\n", algo, avg_latency); + + if (avg_latency < fast_latency) { + fast_latency = avg_latency; + fast_algo = algo; + } + } + + printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); + + return fast_algo; + } + +private: + int bsz, M, N, K; + sycl::queue* handle; + oneapi::mkl::transpose transa, transb; + T *A, *B, *C; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/general_kernels.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/general_kernels.hpp new file mode 100644 index 0000000..ea2135f --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/general_kernels.hpp @@ -0,0 +1,43 @@ +#include +#include + +// #include + +#include "context.hpp" +#include "onemkl_wrappers.hpp" + +#define THREADS 256 + +#define minus_infinity -1 * std::numeric_limits::infinity() + +#define FINAL_MASK 0xffffffff + +template +void launch_fused_add2(T* out, + const T* inp1, + const T* inp2, + int batch_size, + int seq_length, + int hidden_size, + sycl::queue* stream); + +template +void launch_fused_add4(T* out, + const T* inp1, + const T* inp2, + const T* inp3, + const T* inp4, + int batch_size, + int seq_length, + int hidden_size, + sycl::queue* stream); + +template +void launch_fused_add3(T* out, + const T* inp1, + const T* inp2, + const T* inp3, + int batch_size, + int seq_length, + int hidden_size, + sycl::queue* stream); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/normalize_layer.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/normalize_layer.hpp new file mode 100644 index 0000000..ed3a20a --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/normalize_layer.hpp @@ -0,0 +1,197 @@ +#pragma once + +#include +#include "custom_sycl_layers.hpp" + +template +class Normalize_Layer { +public: + struct Config { + uint32_t batchSize; + uint32_t seqLength; + uint32_t hiddenDim; + float epsilon; + bool training; + bool useMean; + Config(uint32_t batch, + uint32_t seq, + uint32_t h, + float epsilon = 1e-12, + bool training = true, + bool useMean = true) + : batchSize(batch), + seqLength(seq), + hiddenDim(h), + epsilon(epsilon), + training(training), + useMean(useMean) + { + } + }; + + Normalize_Layer(Config config) + : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr) + { + } + + ~Normalize_Layer() {} + + void ForwardCheckpoint(int bsz, // batch * seq + T* vals, + const T* residual, + const T* gamma, + const T* betta, + sycl::queue* stream, + bool preLayerNorm = false) + { + launch_bias_residual_layer_norm(vals, + residual, + gamma, + betta, + config_.epsilon, + bsz, + config_.hiddenDim, + stream, + preLayerNorm, + config_.training, + vars, + means); + } + + void Forward(int bsz, + T* vals, + const T* residual, + const T* gamma, + const T* betta, + sycl::queue* stream, + bool preLayerNorm = false) + { + launch_bias_residual_layer_norm(vals, + residual, + gamma, + betta, + config_.epsilon, + bsz, + config_.hiddenDim, + stream, + preLayerNorm, + config_.training, + vars); + } + + void Backward(int bsz, + const T* out_grad, + const T* gamma, + T* gamma_grad, + T* betta_grad, + sycl::queue* stream[2], + T* inp_grad_out, + const T* norm_in = nullptr) + { + launch_layerNorm_backward(out_grad, + norm_in, + vars, + means, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream); + } + + void Backward(int bsz, + const T* out_grad, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + sycl::queue* stream[2], + T* inp_grad_out, + const T* norm_out) + { + launch_layerNorm_backward(out_grad, + norm_out, + vars, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream, + !config_.useMean, + betta); + } + + void BackwardFusedAdd(int bsz, + const T* out_grad1, + const T* out_grad2, + const T* gamma, + T* gamma_grad, + T* betta_grad, + sycl::queue* stream[2], + T* inp_grad_out, + const T* norm_in = nullptr) + { + launch_layerNorm_backward_fused_add(out_grad1, + out_grad2, + norm_in, + vars, + means, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream); + } + + void BackwardFusedAdd(int bsz, + const T* out_grad1, + const T* out_grad2, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + sycl::queue* stream[2], + T* inp_grad_out, + const T* norm_out) + { + launch_layerNorm_backward_fused_add(out_grad1, + out_grad2, + norm_out, + vars, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream, + !config_.useMean, + betta); + } + + inline bool UseMean() const { return config_.useMean; } + + inline void SetVar(T* variance) + { + if (!variance) { throw std::runtime_error("Normalize variance is null."); } + vars = variance; + } + + inline void SetMean(T* mean) + { + if (!mean) { throw std::runtime_error("Normalize mean is null."); } + means = mean; + } + +private: + Config config_; + T* vars; + T* means; + T* vals_hat; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onednn_wrappers.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onednn_wrappers.hpp new file mode 100644 index 0000000..cdef994 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onednn_wrappers.hpp @@ -0,0 +1,31 @@ +#pragma once +#include +#include + +using namespace cl::sycl; +using bf16 = sycl::ext::oneapi::experimental::bfloat16; + +int onednn_matmul_ex(sycl::queue* handle, + bool trans_src, + bool trans_wgt, + int m, + int n, + int k, + const float alpha, + const float beta, + const bf16* src_ptr, + const bf16* wgt_ptr, + bf16* dst_ptr); + +int onednn_batchgemm(sycl::queue* handle, + int m, + int n, + int k, + const float alpha, + const float beta, + const bf16* src_ptr, + const bf16* wgt_ptr, + bf16* dst_ptr, + bool trans_src, + bool trans_wgt, + int batch); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onemkl_wrappers.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onemkl_wrappers.hpp new file mode 100644 index 0000000..ab36803 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/onemkl_wrappers.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include + +#include + +int onemkl_gemm_ex(sycl::queue* handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const float alpha, + const float beta, + const float* A, + const float* B, + float* C); + +int onemkl_gemm_ex(sycl::queue* handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const sycl::half alpha, + const sycl::half beta, + const sycl::half* A, + const sycl::half* B, + sycl::half* C); + +int onemkl_strided_batched_gemm(sycl::queue* handle, + int m, + int n, + int k, + const float alpha, + const float beta, + const float* A, + const float* B, + float* C, + oneapi::mkl::transpose op_A, + oneapi::mkl::transpose op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo = -1); + +int onemkl_strided_batched_gemm(sycl::queue* handle, + int m, + int n, + int k, + const sycl::half alpha, + const sycl::half beta, + const sycl::half* A, + const sycl::half* B, + sycl::half* C, + oneapi::mkl::transpose op_A, + oneapi::mkl::transpose op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo = 99); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/softmax.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/softmax.hpp new file mode 100644 index 0000000..6fdb3aa --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/softmax.hpp @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include "custom_sycl_layers.hpp" + +#include + +template +class Softmax { +public: + struct Config { + size_t batchSize; + size_t heads; + size_t seq_length; + size_t prob_depth; + float temprature; + bool mem_alloc; + Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false) + : batchSize(batch), + heads(h), + seq_length(seq), + prob_depth(prob_size), + temprature(1.0), + mem_alloc(mem_alloc) + { + } + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(int bsz, T* vals, const T* attn_mask, sycl::queue* stream) + { + launch_attn_softmax(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream); + } + + void Backward(int bsz, T* out_grad, const T* soft_out, sycl::queue* stream) + { + launch_attn_softmax_backward_v2( + out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); + } + + inline size_t GetProbDepth() const { return config_.prob_depth; } + + inline size_t GetBatchSize() const { return config_.batchSize; } + + inline size_t GetNumHeads() const { return config_.heads; } + + inline size_t GetSeqLength() const { return config_.seq_length; } + + inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; } + +private: + Config config_; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/strided_batch_gemm.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/strided_batch_gemm.hpp new file mode 100644 index 0000000..47ea7eb --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/strided_batch_gemm.hpp @@ -0,0 +1,268 @@ +#pragma once + +#include +#include +#include "context.hpp" +#include "onednn_wrappers.hpp" +#include "onemkl_wrappers.hpp" + +template +class StridedBatchGemm { +public: + struct Config { + int batch_size; + int m; + int n; + int k; + float alpha; + float beta; + oneapi::mkl::transpose op_A; + oneapi::mkl::transpose op_B; + std::array gemm_algos; + + Config(int batch, + int mm, + int nn, + int kk, + float param_alpha, + float param_beta, + oneapi::mkl::transpose opA, + oneapi::mkl::transpose opB, + const std::array& algos) + : batch_size(batch), + m(mm), + n(nn), + k(kk), + alpha(param_alpha), + beta(param_beta), + op_A(opA), + op_B(opB), + gemm_algos(algos) + { + } + void SetConfig(int mm, int nn, int kk) + { + m = mm; + n = nn; + k = kk; + } + }; + + StridedBatchGemm(const Config& config) : _config(config) + { + k_buf = NULL; + q_buf = NULL; + } + + virtual ~StridedBatchGemm() {} + + void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, sycl::queue* handle) + { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + if constexpr (std::is_same_v) { + onednn_batchgemm(handle, + _config.n, + _config.m, + _config.k, + _config.alpha, + _config.beta, + _buffer_b, + _buffer_a, + output, + _config.op_B == oneapi::mkl::transpose::trans, + _config.op_A == oneapi::mkl::transpose::trans, + bsz); + } else { + onemkl_strided_batched_gemm(handle, + _config.m, + _config.n, + _config.k, + (T)_config.alpha, + (T)_config.beta, + _buffer_a, + _buffer_b, + output, + _config.op_A, + _config.op_B, + stride_a, + stride_b, + stride_c, + bsz, + int(_config.gemm_algos[0])); + } + } + + void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, sycl::queue* handle) + { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + if constexpr (std::is_same_v) { + throw std::runtime_error("Unsupport bf16 strided batch gemm"); + } else { + onemkl_strided_batched_gemm(handle, + _config.m, + _config.n, + _config.k, + (T)_config.alpha, + (T)_config.beta, + _buffer_a, + _buffer_b, + output, + _config.op_A, + _config.op_B, + stride_a, + stride_b, + stride_c, + _config.batch_size, + int(_config.gemm_algos[0])); + } + + k_buf = _buffer_a; + q_buf = _buffer_b; + } + + void Backward(int bsz, + const T* d_output, + const T* _buffer_a, + const T* _buffer_b, + sycl::queue* handle, + T* inpGradA = nullptr, + T* inpGradB = nullptr) + { + if constexpr (std::is_same_v) { + // calculate d_A + if (_config.op_A == oneapi::mkl::transpose::trans) { + onednn_batchgemm(handle, + _config.m, + _config.k, + _config.n, + _config.alpha, + _config.beta, + d_output, + _buffer_b, + inpGradA, + true, + false, + bsz); + + // Calculate d_B. + onednn_batchgemm(handle, + _config.n, + _config.k, + _config.m, + _config.alpha, + _config.beta, + d_output, + _buffer_a, + inpGradB, + false, + false, + bsz); + } else { + onednn_batchgemm(handle, + _config.n, + _config.m, + _config.k, + _config.alpha, + _config.beta, + _buffer_b, + d_output, + inpGradA, + true, + false, + bsz); + + // Calculate d_B. + onednn_batchgemm(handle, + _config.n, + _config.k, + _config.m, + _config.alpha, + _config.beta, + d_output, + _buffer_a, + inpGradB, + false, + true, + bsz); + } + + } else { + int mb = (_config.op_A == oneapi::mkl::transpose::trans ? _config.k : _config.m); + int kb = (_config.op_A == oneapi::mkl::transpose::trans ? _config.m : _config.k); + + int stride_a = mb * _config.n; + int stride_b = _config.n * kb; + int stride_c = _config.m * _config.k; + + // B need to transpose. + oneapi::mkl::transpose op_b = + (_config.op_B == oneapi::mkl::transpose::trans ? oneapi::mkl::transpose::nontrans + : oneapi::mkl::transpose::trans); + + // calculate d_A + onemkl_strided_batched_gemm( + handle, + mb, + kb, + _config.n, + (T)_config.alpha, + (T)_config.beta, + (_config.op_A == oneapi::mkl::transpose::trans ? _buffer_b : d_output), + (_config.op_A == oneapi::mkl::transpose::trans ? d_output : _buffer_b), + inpGradA, + oneapi::mkl::transpose::nontrans, + op_b, + stride_a, + stride_b, + stride_c, + bsz, + int(_config.gemm_algos[1])); + + // A need to transpose. + oneapi::mkl::transpose op_a = + (_config.op_A == oneapi::mkl::transpose::trans ? oneapi::mkl::transpose::nontrans + : oneapi::mkl::transpose::trans); + + stride_a = _config.m * _config.k; + stride_b = _config.m * _config.n; + stride_c = _config.n * _config.k; + + // Calculate d_B. + onemkl_strided_batched_gemm(handle, + _config.k, + _config.n, + _config.m, + (T)_config.alpha, + (T)_config.beta, + _buffer_a, + d_output, + inpGradB, + op_a, + oneapi::mkl::transpose::nontrans, + stride_a, + stride_b, + stride_c, + bsz, + int(_config.gemm_algos[2])); + } + } + + inline int GetN() const { return _config.k; } + + inline const T* GetBufferA() const { return k_buf; } + + inline const T* GetBufferB() const { return q_buf; } + + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + +private: + Config _config; + const T* q_buf; + const T* k_buf; +}; diff --git a/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/type_shim.hpp b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/type_shim.hpp new file mode 100644 index 0000000..8f70feb --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/includes/sycl/type_shim.hpp @@ -0,0 +1,112 @@ +/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ +#include +#include + +// Forward/backward compatiblity hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Double: { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +template +T reduce_block_into_lanes(T* x, + T val, + sycl::nd_item<3> item_ct1, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = + item_ct1.get_local_id(2) + item_ct1.get_local_id(1) * item_ct1.get_local_range().get(2); + int blockSize = item_ct1.get_local_range(2) * + item_ct1.get_local_range(1); // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + item_ct1.barrier(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + item_ct1.barrier(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + item_ct1.barrier(); + } + + return final; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/dropout_kernels.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/dropout_kernels.dp.cpp new file mode 100644 index 0000000..1588c0b --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/dropout_kernels.dp.cpp @@ -0,0 +1,1194 @@ +#include +#include "sycl/custom_sycl_layers.hpp" +using namespace cl::sycl; + +const int unroll_factor = 4; + +void dropout_kernel(const int N, + const float ratio, + float* out, + const float* Xdata, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + uint8_t m[unroll_factor]; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + int i = j * unroll_factor; + + mask[i] = (uint8_t)m[0]; + mask[i + 1] = (uint8_t)m[1]; + mask[i + 2] = (uint8_t)m[2]; + mask[i + 3] = (uint8_t)m[3]; + + out[i] = Xdata[i] * scale * m[0]; + out[i + 1] = Xdata[i + 1] * scale * m[1]; + out[i + 2] = Xdata[i + 2] * scale * m[2]; + out[i + 3] = Xdata[i + 3] * scale * m[3]; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = Xdata[i] * scale * m; + mask[i] = m; + } + } +} + +void dropout_kernel(const int N, + const float ratio, + bf16* out, + const bf16* Xdata, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + ushort* out_cast = reinterpret_cast(out); + const ushort* Xdata_cast = reinterpret_cast(Xdata); + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + uint8_t m[unroll_factor]; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + int i = j * unroll_factor; + + mask[i] = (uint8_t)m[0]; + mask[i + 1] = (uint8_t)m[1]; + mask[i + 2] = (uint8_t)m[2]; + mask[i + 3] = (uint8_t)m[3]; + + out_cast[i] = bf16::from_float(bf16::to_float(Xdata_cast[i]) * scale * m[0]); + out_cast[i + 1] = bf16::from_float(bf16::to_float(Xdata_cast[i + 1]) * scale * m[1]); + out_cast[i + 2] = bf16::from_float(bf16::to_float(Xdata_cast[i + 2]) * scale * m[2]); + out_cast[i + 3] = bf16::from_float(bf16::to_float(Xdata_cast[i + 3]) * scale * m[3]); + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out_cast[i] = bf16::from_float(bf16::to_float(Xdata_cast[i]) * scale * m); + mask[i] = m; + } + } +} + +void dropout_kernel(const int N, + const float ratio, + half* out, + const half* Xdata, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + + size_t idx = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + +#ifdef __STOCHASTIC_MODE__ + + const half2 h_scale = vec{scale}.convert(); + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + uint32_t m_32; + uint8_t* m = reinterpret_cast(&m_32); + + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + half2 mask_h[2]; + float2 mask_f[2]; + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + half2* x_h = reinterpret_cast(&x_f); + + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + + mask_h[0] = mask_f[0].convert(); + mask_h[1] = mask_f[1].convert 9; + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + + mask_cast[j] = m_32; + } + +#else + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const half2* vals_half = reinterpret_cast(Xdata + i); + float2 vals_half_f[2]; + vals_half_f[0] = vals_half[0].convert(); + vals_half_f[1] = vals_half[1].convert(); + + uint8_t m[unroll_factor]; + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + out[i] = vec{vals_half_f[0].x() * scale * m[0]} + .convert()[0]; + out[i + 1] = vec{vals_half_f[0].y() * scale * m[1]} + .convert()[0]; + out[i + 2] = vec{vals_half_f[1].x() * scale * m[2]} + .convert()[0]; + out[i + 3] = vec{vals_half_f[1].y() * scale * m[3]} + .convert()[0]; + + mask[i] = m[0]; + mask[i + 1] = m[1]; + mask[i + 2] = m[2]; + mask[i + 3] = m[3]; + } + +#endif + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = vec{(float)Xdata[i] * scale * m} + .convert()[0]; + mask[i] = m; + } + } +} + +void dropout_kernel_bwd(const int N, + const float ratio, + const float* Xdata, + float* out, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + out[i] = mask[i] * Xdata[i] * scale; + out[i + 1] = mask[i + 1] * Xdata[i + 1] * scale; + out[i + 2] = mask[i + 2] * Xdata[i + 2] * scale; + out[i + 3] = mask[i + 3] * Xdata[i + 3] * scale; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + for (int i = high_index; i < N; i++) { out[i] = mask[i] * Xdata[i] * scale; } + } +} + +void dropout_kernel_bwd(const int N, + const float ratio, + const bf16* Xdata, + bf16* out, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + + const ushort* Xdata_cast = reinterpret_cast(Xdata); + ushort* out_cast = reinterpret_cast(out); + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + out_cast[i] = bf16::from_float(mask[i] * bf16::to_float(Xdata_cast[i]) * scale); + out_cast[i + 1] = bf16::from_float(mask[i + 1] * bf16::to_float(Xdata_cast[i + 1]) * scale); + out_cast[i + 2] = bf16::from_float(mask[i + 2] * bf16::to_float(Xdata_cast[i + 2]) * scale); + out_cast[i + 3] = bf16::from_float(mask[i + 3] * bf16::to_float(Xdata_cast[i + 3]) * scale); + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out_cast[i] = bf16::from_float(mask[i] * bf16::to_float(Xdata_cast[i]) * scale); + } + } +} + +void dropout_kernel_bwd(const int N, + const float ratio, + const half* Xdata, + half* out, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + +#ifdef __STOCHASTIC_MODE__ + + const half2 h_scale = vec{scale}.convert(); + + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + half2* x_h = reinterpret_cast(&x_f); + + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + +#pragma unroll + for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]); + + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + } + +#else + + const half h_scale = vec{scale}.convert()[0]; + const half h_zero = vec{0.0}.convert()[0]; + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const half2* vals_half = reinterpret_cast(Xdata + i); + + uint8_t* m = mask + i; + + float2 vals_half_f[2]; + + vals_half_f[0] = vals_half[0].convert(); + vals_half_f[1] = vals_half[1].convert(); + + out[i] = vec{vals_half_f[0].x() * scale * m[0]} + .convert()[0]; + out[i + 1] = vec{vals_half_f[0].y() * scale * m[1]} + .convert()[0]; + out[i + 2] = vec{vals_half_f[1].x() * scale * m[2]} + .convert()[0]; + out[i + 3] = vec{vals_half_f[1].y() * scale * m[3]} + .convert()[0]; + } + +#endif + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = vec{(float)Xdata[i] * scale * mask[i]} + .convert()[0]; + } + } +} + +template +void launch_dropout(T* out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + queue* stream, + bool bwd) +{ + /* + * dropout.Forward + */ + assert(unroll_factor == 4); + + range<3> grid_dim = range<3>(1, 1, DS_GET_BLOCKS(total_count / unroll_factor)); + range<3> block_dim = range<3>(1, 1, DS_CUDA_NUM_THREADS); + + if (dim > 512) { + block_dim[2] >>= 1; + grid_dim[2] <<= 1; + } + uint64_t inc = total_count / grid_dim[2] / block_dim[2]; + std::pair seed = SyclContext::Instance().IncrementOffset(inc); + if (bwd) + stream->submit([&](handler& cgh) { + cgh.parallel_for( + nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + dropout_kernel_bwd(total_count, ratio, vals, out, mask, seed, item_ct1); + }); + }); + else + stream->submit([&](handler& cgh) { + cgh.parallel_for( + nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + dropout_kernel(total_count, ratio, out, vals, mask, seed, item_ct1); + }); + }); +} + +template void launch_dropout(float* out, + const float* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + queue* stream, + bool); +template void launch_dropout(bf16* out, + const bf16* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + queue* stream, + bool); +template void launch_dropout(half* out, + const half* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + queue* stream, + bool); + +void dropout_grad_kernel(const int N, + const float scale, + float* Xdata, + uint8_t* mask, + nd_item<3> item_ct1) +{ + DPCPP_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; } +} + +void dropout_grad_kernel(const int N, + const float scale, + bf16* Xdata, + uint8_t* mask, + nd_item<3> item_ct1) +{ + ushort* Xdata_cast = reinterpret_cast(Xdata); + DPCPP_1D_KERNEL_LOOP(i, N) + { + Xdata_cast[i] = bf16::from_float(bf16::to_float(Xdata_cast[i]) * scale * mask[i]); + } +} + +void dropout_grad_kernel(const int N, + const float scale, + half* Xdata, + uint8_t* mask, + nd_item<3> item_ct1) +{ + const half2 h_scale = float2{scale, scale}.convert(); + float2* x_cast = reinterpret_cast(Xdata); + uint32_t* mask_cast = reinterpret_cast(mask); + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + +#ifdef __STOCHASTIC_MODE__ + + half2* x_data_h = reinterpret_cast(&x_data); + half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_data_h[0] * h_scale * mask_h[0]; + result_h[1] = x_data_h[1] * h_scale * mask_h[1]; + +#else + + half* x_data_h = reinterpret_cast(&x_data); + float2 result[2]; + + result[0].x() = (float)x_data_h[0] * scale * m[0]; + result[0].y() = (float)x_data_h[1] * scale * m[1]; + result[1].x() = (float)x_data_h[2] * scale * m[2]; + result[1].y() = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = result[0].convert(); + result_h[1] = result[1].convert(); + +#endif + x_cast[j] = result_f; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + for (int i = high_index; i < N; i++) { + Xdata[i] = vec{(float)Xdata[i] * scale * mask[i]} + .convert()[0]; + } + } +} + +template +void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, queue* stream) +{ + /* + * Dropout.Backward0 + */ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + range<3> grid_dim = range<3>(1, 1, DS_GET_BLOCKS(total_count / unroll_factor)); + range<3> block_dim = range<3>(1, 1, DS_CUDA_NUM_THREADS); + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + dropout_grad_kernel(total_count, scale, vals, mask, item_ct1); + }); + }); +} + +template void launch_dropout_grad(float* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream); +template void launch_dropout_grad(bf16* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream); +template void launch_dropout_grad(half* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream); + +void dropout_grad_kernel(const int N, + const float scale, + const float* Xdata, + float* out, + uint8_t* mask, + nd_item<3> item_ct1) +{ + DPCPP_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; } +} + +void dropout_grad_kernel(const int N, + const float scale, + const bf16* Xdata, + bf16* out, + uint8_t* mask, + nd_item<3> item_ct1) +{ + const ushort* Xdata_cast = reinterpret_cast(Xdata); + ushort* out_cast = reinterpret_cast(out); + DPCPP_1D_KERNEL_LOOP(i, N) + { + out_cast[i] = bf16::from_float(bf16::to_float(Xdata_cast[i]) * scale * mask[i]); + } +} + +void dropout_grad_kernel(const int N, + const float scale, + const half* Xdata, + half* out, + uint8_t* mask, + nd_item<3> item_ct1) +{ + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + const uint32_t* mask_cast = reinterpret_cast(mask); + + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + + DPCPP_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + half* x_data_h = reinterpret_cast(&x_data); + float2 result[2]; + + result[0].x() = (float)x_data_h[0] * scale * m[0]; + result[0].y() = (float)x_data_h[1] * scale * m[1]; + result[1].x() = (float)x_data_h[2] * scale * m[2]; + result[1].y() = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = result[0].convert(); + result_h[1] = result[1].convert(); + + out_cast[j] = result_f; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = vec{(float)Xdata[i] * scale * mask[i]} + .convert()[0]; + } + } +} + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream) +{ + /* + * Dropout.Backward1 + */ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + range<3> grid_dim = range<3>(1, 1, DS_GET_BLOCKS(total_count / unroll_factor)); + range<3> block_dim = range<3>(1, 1, DS_CUDA_NUM_THREADS); + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + dropout_grad_kernel(total_count, scale, vals, vals_out, mask, item_ct1); + }); + }); +} +template void launch_dropout_grad(float* vals_out, + const float* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream); +template void launch_dropout_grad(bf16* vals_out, + const bf16* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream); +template void launch_dropout_grad(half* vals_out, + const half* vals, + uint8_t* mask, + int total_count, + float ratio, + queue* stream); + +/* + * not called in transformer kernel Shi Yuankun 2021/10/21 + */ +void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* bias, + float* Xdata, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = + item_ct1.get_group(2) * item_ct1.get_local_range().get(2) + item_ct1.get_local_id(2); + int tid = item_ct1.get_local_id(2) % (dim / unroll_factor); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + float4* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float4* bias_cast = reinterpret_cast(bias); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + float4 x_data = Xdata_cast[j]; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + + x_data.x() += b_data.x(); + x_data.y() += b_data.y(); + x_data.z() += b_data.z(); + x_data.w() += b_data.w(); + + x_data.x() = x_data.x() * scale * m[0]; + x_data.y() = x_data.y() * scale * m[1]; + x_data.z() = x_data.z() * scale * m[2]; + x_data.w() = x_data.w() * scale * m[3]; + + mask_32[j] = m_32; + Xdata_cast[j] = x_data; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = Xdata[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = x_data * scale * m; + mask[i] = m; + } + } +} + +void dropout_kernel(const int N, + const int dim, + const float ratio, + const half* bias, + half* Xdata, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = + item_ct1.get_group(2) * item_ct1.get_local_range().get(2) + item_ct1.get_local_id(2); + int tid = item_ct1.get_local_id(2) % (dim / unroll_factor); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + float2* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float2* bias_cast = reinterpret_cast(bias); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + + float2 data_f; + half2* data_h = reinterpret_cast(&data_f); + + float2 bias_f; + half2* bias_h = reinterpret_cast(&bias_f); + + data_f = Xdata_cast[j]; + bias_f = bias_cast[j % (dim / unroll_factor)]; + + float2 data_h_0 = data_h[0].convert(); + float2 data_h_1 = data_h[1].convert(); + + float2 bias_h_0 = bias_h[0].convert(); + float2 bias_h_1 = bias_h[1].convert(); + + data_h_0.x() += bias_h_0.x(); + data_h_0.y() += bias_h_0.y(); + data_h_1.x() += bias_h_1.x(); + data_h_1.y() += bias_h_1.y(); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + data_h_0.x() = + vec{data_h_0.x() * scale * m[0]}.convert()[0]; + data_h_0.y() = + vec{data_h_0.y() * scale * m[1]}.convert()[0]; + data_h_1.x() = + vec{data_h_1.x() * scale * m[2]}.convert()[0]; + data_h_1.y() = + vec{data_h_1.y() * scale * m[3]}.convert()[0]; + + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + + result_h[0] = data_h_0.convert(); + result_h[1] = data_h_1.convert(); + + Xdata_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)Xdata[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = + vec{x_data * scale * m}.convert()[0]; + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + + range<3> grid_dim = range<3>(1, 1, DS_GET_BLOCKS(total_count)); + range<3> block_dim = range<3>(1, 1, DS_CUDA_NUM_THREADS); + + uint64_t inc = (batch * dim) / grid_dim[2] / block_dim[2]; + std::pair seed = SyclContext::Instance().IncrementOffset(inc); + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + dropout_kernel(total_count, dim, ratio, bias, out, mask, seed, item_ct1); + }); + }); +} + +template void launch_dropout(float*, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream); +template void launch_dropout(half*, + const half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream); + +void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* input, + const float* residual, + const float* bias, + float* out, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = + item_ct1.get_group(2) * item_ct1.get_local_range().get(2) + item_ct1.get_local_id(2); + int tid = item_ct1.get_local_id(2) % (dim / unroll_factor); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + float4* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float4* bias_cast = reinterpret_cast(bias); + const float4* residual_cast = reinterpret_cast(residual); + const float4* input_cast = reinterpret_cast(input); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + float4 out_data; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + float4 res_data = residual_cast[j]; + float4 inp_data = input_cast[j]; + + out_data.x() = (b_data.x() + inp_data.x()); + out_data.y() = (b_data.y() + inp_data.y()); + out_data.z() = (b_data.z() + inp_data.z()); + out_data.w() = (b_data.w() + inp_data.w()); + + out_data.x() = out_data.x() * scale * m[0]; + out_data.y() = out_data.y() * scale * m[1]; + out_data.z() = out_data.z() * scale * m[2]; + out_data.w() = out_data.w() * scale * m[3]; + + out_data.x() += res_data.x(); + out_data.y() += res_data.y(); + out_data.z() += res_data.z(); + out_data.w() += res_data.w(); + + mask_32[j] = m_32; + out_cast[j] = out_data; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = input[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += residual[i]; + + out[i] = x_data; + mask[i] = m; + } + } +} + +void dropout_kernel(const int N, + const int dim, + const float ratio, + const bf16* input, + const bf16* residual, + const bf16* bias, + bf16* out, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = + item_ct1.get_group(2) * item_ct1.get_local_range().get(2) + item_ct1.get_local_id(2); + int tid = item_ct1.get_local_id(2) % (dim / unroll_factor); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + ushort4* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const ushort4* bias_cast = reinterpret_cast(bias); + const ushort4* residual_cast = reinterpret_cast(residual); + const ushort4* input_cast = reinterpret_cast(input); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + float4 out_data; + float4 b_data = { + bf16::to_float(bias_cast[j % (dim / unroll_factor)].x()), + bf16::to_float(bias_cast[j % (dim / unroll_factor)].y()), + bf16::to_float(bias_cast[j % (dim / unroll_factor)].z()), + bf16::to_float(bias_cast[j % (dim / unroll_factor)].w()), + }; + float4 res_data = {bf16::to_float(residual_cast[j].x()), + bf16::to_float(residual_cast[j].y()), + bf16::to_float(residual_cast[j].z()), + bf16::to_float(residual_cast[j].w())}; + float4 inp_data = {bf16::to_float(input_cast[j].x()), + bf16::to_float(input_cast[j].y()), + bf16::to_float(input_cast[j].z()), + bf16::to_float(input_cast[j].w())}; + + out_data.x() = (b_data.x() + inp_data.x()); + out_data.y() = (b_data.y() + inp_data.y()); + out_data.z() = (b_data.z() + inp_data.z()); + out_data.w() = (b_data.w() + inp_data.w()); + + out_data.x() = out_data.x() * scale * m[0]; + out_data.y() = out_data.y() * scale * m[1]; + out_data.z() = out_data.z() * scale * m[2]; + out_data.w() = out_data.w() * scale * m[3]; + + out_data.x() += res_data.x(); + out_data.y() += res_data.y(); + out_data.z() += res_data.z(); + out_data.w() += res_data.w(); + + mask_32[j] = m_32; + out_cast[j] = {bf16::from_float(out_data.x()), + bf16::from_float(out_data.y()), + bf16::from_float(out_data.z()), + bf16::from_float(out_data.w())}; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + ushort* out_cast = reinterpret_cast(out); + const ushort* bias_cast = reinterpret_cast(bias); + const ushort* residual_cast = reinterpret_cast(residual); + const ushort* input_cast = reinterpret_cast(input); + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = bf16::to_float(input_cast[i]) + bf16::to_float(bias_cast[i % dim]); + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += bf16::to_float(residual_cast[i]); + + out_cast[i] = bf16::from_float(x_data); + mask[i] = m; + } + } +} + +void dropout_kernel(const int N, + const int dim, + const float ratio, + const half* input, + const half* residual, + const half* bias, + half* out, + uint8_t* mask, + const std::pair& seed, + nd_item<3> item_ct1) +{ + const float scale = 1. / (1. - ratio); + size_t idx = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + int tid = item_ct1.get_local_id(2) % (dim / unroll_factor); + + oneapi::mkl::rng::device::philox4x32x10<4> engine(seed.first, {idx * 4, seed.second}); + oneapi::mkl::rng::device::uniform<> distr; + + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float2* bias_cast = reinterpret_cast(bias); + const float2* residual_cast = reinterpret_cast(residual); + const float2* input_cast = reinterpret_cast(input); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + + float2 data_f; + half2* data_h = reinterpret_cast(&data_f); + + float2 bias_f; + half2* bias_h = reinterpret_cast(&bias_f); + + float2 residual_f; + half2* residual_h = reinterpret_cast(&residual_f); + + float2 input_f; + half2* input_h = reinterpret_cast(&input_f); + + bias_f = bias_cast[j % (dim / unroll_factor)]; + residual_f = residual_cast[j]; + input_f = input_cast[j]; + + float2 data_h_0 = data_h[0].convert(); + float2 data_h_1 = data_h[1].convert(); + + float2 bias_h_0 = bias_h[0].convert(); + float2 bias_h_1 = bias_h[1].convert(); + + float2 residual_h_0 = residual_h[0].convert(); + float2 residual_h_1 = residual_h[1].convert(); + + float2 input_h_0 = input_h[0].convert(); + float2 input_h_1 = input_h[1].convert(); + + data_h_0.x() = (bias_h_0.x() + input_h_0.x()); + data_h_0.y() = (bias_h_0.y() + input_h_0.y()); + data_h_1.x() = (bias_h_1.x() + input_h_1.x()); + data_h_1.y() = (bias_h_1.y() + input_h_1.y()); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x() > ratio); + m[1] = (uint8_t)(rand.y() > ratio); + m[2] = (uint8_t)(rand.z() > ratio); + m[3] = (uint8_t)(rand.w() > ratio); + + data_h_0.x() = + vec{data_h_0.x() * scale * m[0]}.convert()[0]; + data_h_0.y() = + vec{data_h_0.y() * scale * m[1]}.convert()[0]; + data_h_1.x() = + vec{data_h_1.x() * scale * m[2]}.convert()[0]; + data_h_1.y() = + vec{data_h_1.y() * scale * m[3]}.convert()[0]; + + data_h_0.x() += residual_h_0.x(); + data_h_0.y() += residual_h_0.y(); + data_h_1.x() += residual_h_1.x(); + data_h_1.y() += residual_h_1.y(); + + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + + result_h[0] = data_h_0.convert(); + result_h[1] = data_h_1.convert(); + + out_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = ((((N / unroll_factor) - 1) / item_ct1.get_local_range().get(2) + 1) * + (unroll_factor * item_ct1.get_local_range().get(2))) + + item_ct1.get_local_id(2); + if (N > high_index) { + float4 rand = oneapi::mkl::rng::device::generate(distr, engine); + float* rand_data = &(rand.x()); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)input[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += (float)residual[i]; + + out[i] = vec{x_data}.convert()[0]; + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* input, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + range<3> grid_dim = range<3>(1, 1, DS_GET_BLOCKS(total_count)); + range<3> block_dim = range<3>(1, 1, DS_CUDA_NUM_THREADS); + + uint64_t inc = (batch * dim) / grid_dim[2] / block_dim[2]; + std::pair seed = SyclContext::Instance().IncrementOffset(inc); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + dropout_kernel( + total_count, dim, ratio, input, residual, bias, out, mask, seed, item_ct1); + }); + }); +} + +template void launch_dropout(float*, + const float*, + const float* residual, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream); +template void launch_dropout(bf16*, + const bf16*, + const bf16* residual, + const bf16* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream); +template void launch_dropout(half*, + const half*, + const half* residual, + const half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + queue* stream); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_dropout_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_dropout_sycl.dp.cpp new file mode 100644 index 0000000..4a6299e --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_dropout_sycl.dp.cpp @@ -0,0 +1,92 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/dropout.hpp" + +template +std::vector dropout_forward(float ratio, + uint32_t dim, + int bsz, + const torch::Tensor& vals) +{ + CHECK_INPUT(vals); + auto output = torch::empty_like(vals); + + auto uint8_options = torch::TensorOptions() + .dtype(torch::kInt8) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(false); + + auto mask = torch::empty({bsz, dim}, uint8_options); + + const T* input_ptr = (const T*)vals.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + uint8_t* mask_ptr = (uint8_t*)mask.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Dropout _dropout = Dropout(typename Dropout::Config(ratio, dim)); + _dropout.SetMask(mask_ptr); + _dropout.Forward(bsz, output_ptr, input_ptr, q); + return {output, mask}; +} + +template +std::vector dropout_forward_with_bias(float ratio, + uint32_t dim, + int bsz, + const torch::Tensor& vals, + const torch::Tensor& bias, + const torch::Tensor& residual) +{ + CHECK_INPUT(vals); + CHECK_INPUT(bias); + CHECK_INPUT(residual); + auto output = torch::empty_like(vals); + + auto uint8_options = torch::TensorOptions() + .dtype(torch::kInt8) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(false); + + auto mask = torch::empty({bsz, dim}, uint8_options); + + const T* input_ptr = (const T*)vals.data_ptr(); + const T* bias_ptr = (const T*)bias.data_ptr(); + const T* residual_ptr = (const T*)residual.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + uint8_t* mask_ptr = (uint8_t*)mask.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Dropout _dropout = Dropout(typename Dropout::Config(ratio, dim)); + _dropout.SetMask(mask_ptr); + _dropout.ForwardWithBias(bsz, output_ptr, input_ptr, residual_ptr, bias_ptr, q); + return {output, mask}; +} + +template +std::vector dropout_backward(float ratio, + uint32_t dim, + int bsz, + torch::Tensor& vals, + torch::Tensor& mask, + bool in_place) +{ + CHECK_INPUT(vals); + CHECK_INPUT(mask); + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Dropout _dropout = Dropout(typename Dropout::Config(ratio, dim)); + uint8_t* mask_ptr = (uint8_t*)mask.data_ptr(); + _dropout.SetMask(mask_ptr); + if (in_place) { + T* d_input_ptr = (T*)vals.data_ptr(); + _dropout.Backward(bsz, d_input_ptr, q); + return {vals}; + } else { + auto output = torch::empty_like(vals); + const T* d_input_ptr = (const T*)vals.data_ptr(); + T* d_output_ptr = (T*)output.data_ptr(); + _dropout.Backward(bsz, d_output_ptr, d_input_ptr, q); + return {output}; + } +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_feedforward_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_feedforward_sycl.dp.cpp new file mode 100644 index 0000000..a131d5f --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_feedforward_sycl.dp.cpp @@ -0,0 +1,81 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/feed_forward.hpp" + +template +std::vector feedforward_forward(int bsz, + int seq_len, + int hidden_size, + const torch::Tensor& input, + const torch::Tensor& weights) +{ + CHECK_INPUT(input); + CHECK_INPUT(weights); + + int batchSize = bsz * seq_len; + int inputSize = hidden_size; + int outputSize = 3 * hidden_size; + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + const T* input_ptr = (const T*)input.data_ptr(); + const T* weights_ptr = (const T*)weights.data_ptr(); + + auto output = torch::empty({bsz, seq_len, outputSize}, options); + + T* output_ptr = (T*)output.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + FeedForward _ff = + FeedForward(typename FeedForward::Config(batchSize, outputSize, inputSize)); + + _ff.Forward(batchSize, input_ptr, weights_ptr, output_ptr, q); + return {output}; +} + +template +std::vector feedforward_backward(int bsz, + int seq_len, + int hidden_size, + const torch::Tensor& grad_out, + const torch::Tensor& input, + const torch::Tensor& weights) +{ + CHECK_INPUT(grad_out); + CHECK_INPUT(input); + CHECK_INPUT(weights); + + int batchSize = bsz * seq_len; + int inputSize = hidden_size; + int outputSize = 3 * hidden_size; + + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + const T* grad_out_ptr = (const T*)grad_out.data_ptr(); + const T* input_ptr = (const T*)input.data_ptr(); + const T* weights_ptr = (const T*)weights.data_ptr(); + + auto grad_weights = torch::empty(weights.sizes(), options); + auto grad_bias = torch::empty({outputSize}, options); + auto grad_input = torch::empty(input.sizes(), options); + + T* grad_w_ptr = (T*)grad_weights.data_ptr(); + T* grad_b_ptr = (T*)grad_bias.data_ptr(); + T* grad_i_ptr = (T*)grad_input.data_ptr(); + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + FeedForward _ff = + FeedForward(typename FeedForward::Config(batchSize, outputSize, inputSize)); + + _ff.Backward( + batchSize, grad_out_ptr, input_ptr, weights_ptr, grad_w_ptr, grad_b_ptr, q, q, grad_i_ptr); + return {grad_input, grad_weights, grad_bias}; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_gelu_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_gelu_sycl.dp.cpp new file mode 100644 index 0000000..ddc182f --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_gelu_sycl.dp.cpp @@ -0,0 +1,39 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/gelu.hpp" + +template +std::vector gelu_forward(int intermediate_size, + int bsz_seq, + const torch::Tensor& input, + const torch::Tensor& bias) +{ + CHECK_INPUT(input); + CHECK_INPUT(bias); + const T* input_ptr = (const T*)input.data_ptr(); + const T* bias_ptr = (const T*)bias.data_ptr(); + auto output = torch::empty_like(input); + T* output_ptr = (T*)output.data_ptr(); + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Gelu _gelu = Gelu(typename Gelu::Config(intermediate_size)); + _gelu.ForwardWithBiasAdd(bsz_seq, input_ptr, bias_ptr, output_ptr, q); + return {output}; +} + +template +std::vector gelu_backward(torch::Tensor& d_output, + int intermediate_size, + int bsz_seq, + const torch::Tensor& input, + const torch::Tensor& bias) +{ + CHECK_INPUT(input); + CHECK_INPUT(bias); + const T* input_ptr = (const T*)input.data_ptr(); + const T* bias_ptr = (const T*)bias.data_ptr(); + T* d_output_ptr = (T*)d_output.data_ptr(); + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Gelu _gelu = Gelu(typename Gelu::Config(intermediate_size)); + _gelu.Backward(bsz_seq, d_output_ptr, input_ptr, bias_ptr, q); + return {d_output}; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_layer_reorder_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_layer_reorder_sycl.dp.cpp new file mode 100644 index 0000000..4ae77e4 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_layer_reorder_sycl.dp.cpp @@ -0,0 +1,122 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/custom_sycl_layers.hpp" +#include "sycl/general_kernels.hpp" + +template +std::vector transform4d_0213(const torch::Tensor& input, + int batch, + int seq_len, + int hidden_size, + int num_heads, + int trans_count) +{ + CHECK_INPUT(input); + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + torch::Tensor output; + if (trans_count == 3) + // trans_count=3 + output = torch::empty({batch, seq_len, 3, num_heads, hidden_size / num_heads}, options); + else + // for 1 attn_o_inp, trans_count=1 + output = torch::empty({batch, seq_len, num_heads, hidden_size / num_heads}, options); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + const T* input_ptr = (const T*)input.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + // trans_count=1 + // launch_transform4d_0213(output_ptr, input_ptr, batch, num_heads, seq_len, + // hidden_size, q, 1); + // trans_count=3 + launch_transform4d_0213( + output_ptr, input_ptr, batch, num_heads, seq_len, hidden_size, q, trans_count); + return {output}; +} + +template +std::vector bias_add_transform_0213(const torch::Tensor& input, + const torch::Tensor& bias, + int batch, + int seq_len, + int hidden_size, + int num_heads) +{ + CHECK_INPUT(input); + CHECK_INPUT(bias); + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + auto output = torch::empty({3, batch, num_heads, seq_len, hidden_size / num_heads}, options); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + const T* input_ptr = (const T*)input.data_ptr(); + const T* bias_ptr = (const T*)bias.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + launch_bias_add_transform_0213( + output_ptr, input_ptr, bias_ptr, batch, seq_len, hidden_size, num_heads, q, 3); + return {output}; +} + +template +std::vector transform_0213(const torch::Tensor& input, + int batch, + int seq_len, + int hidden_size, + int num_heads) +{ + CHECK_INPUT(input); + + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + auto output = torch::empty({batch, num_heads, seq_len, hidden_size / num_heads}, options); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + const T* input_ptr = (const T*)input.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + + launch_transform_0213(output_ptr, input_ptr, batch, seq_len, hidden_size, num_heads, q); + return {output}; +} + +template +std::vector fused_add2(const torch::Tensor& input1, + const torch::Tensor& input2, + int batch, + int seq_len, + int hidden_size) +{ + CHECK_INPUT(input1); + CHECK_INPUT(input2); + + auto options = torch::TensorOptions() + .dtype(input1.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + auto output = torch::empty({batch, seq_len, hidden_size}, options); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + const T* input_ptr1 = (const T*)input1.data_ptr(); + const T* input_ptr2 = (const T*)input2.data_ptr(); + T* output_ptr = (T*)output.data_ptr(); + + launch_fused_add2(output_ptr, input_ptr1, input_ptr2, batch, seq_len, hidden_size, q); + return {output}; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_normalize_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_normalize_sycl.dp.cpp new file mode 100644 index 0000000..b110bc9 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_normalize_sycl.dp.cpp @@ -0,0 +1,151 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/normalize_layer.hpp" + +template +std::vector normalize_forward(const int batch, + const int seq_len, + const int hidden_size, + const torch::Tensor& residual, + const torch::Tensor& gamma, + const torch::Tensor& betta, + torch::Tensor& mean, + torch::Tensor& var, + const bool preln, + const bool wmean, + const float epsilon) +{ + CHECK_INPUT(residual); + CHECK_INPUT(gamma); + CHECK_INPUT(betta); + + int bsz_seq = batch * seq_len; + + auto options = torch::TensorOptions() + .dtype(residual.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + auto output = torch::empty({batch, seq_len, hidden_size}, options); + + T* output_ptr = (T*)output.data_ptr(); + T* mean_ptr = (T*)mean.data_ptr(); + T* var_ptr = (T*)var.data_ptr(); + const T* residual_ptr = (const T*)residual.data_ptr(); + const T* gamma_ptr = (const T*)gamma.data_ptr(); + const T* betta_ptr = (const T*)betta.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Normalize_Layer _norm( + typename Normalize_Layer::Config(batch, seq_len, hidden_size, epsilon, true, wmean)); + _norm.SetMean(mean_ptr); + _norm.SetVar(var_ptr); + + if (wmean) + _norm.ForwardCheckpoint(bsz_seq, output_ptr, residual_ptr, gamma_ptr, betta_ptr, q); + else + _norm.Forward(bsz_seq, output_ptr, residual_ptr, gamma_ptr, betta_ptr, q); + return {output}; +} + +template +std::vector normalize_backward(const int batch, + const int seq_len, + const int hidden_size, + const torch::Tensor& input, + const torch::Tensor& gamma, + const torch::Tensor& betta, + const torch::Tensor& output, + const torch::Tensor& out1_grad, + const torch::Tensor& out2_grad, + torch::Tensor& mean, + torch::Tensor& var, + const bool preln, + const bool wmean, + const float epsilon) +{ + CHECK_INPUT(input); + CHECK_INPUT(output); + CHECK_INPUT(out1_grad); + CHECK_INPUT(out2_grad); + CHECK_INPUT(gamma); + CHECK_INPUT(betta); + int bsz_seq = batch * seq_len; + + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + auto gamma_grad = torch::empty({hidden_size}, options); + auto betta_grad = torch::empty({hidden_size}, options); + auto input_grad = torch::empty({batch, seq_len, hidden_size}, options); + + const T* input_ptr = (const T*)input.data_ptr(); + const T* out1_grad_ptr = (const T*)out1_grad.data_ptr(); + const T* out2_grad_ptr = (const T*)out2_grad.data_ptr(); + const T* gamma_ptr = (const T*)gamma.data_ptr(); + const T* betta_ptr = (const T*)betta.data_ptr(); + const T* output_ptr = (const T*)output.data_ptr(); + T* gamma_grad_ptr = (T*)gamma_grad.data_ptr(); + T* betta_grad_ptr = (T*)betta_grad.data_ptr(); + T* inp_grad_ptr = (T*)input_grad.data_ptr(); + T* mean_ptr = (T*)mean.data_ptr(); + T* var_ptr = (T*)var.data_ptr(); + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + Normalize_Layer _norm( + typename Normalize_Layer::Config(batch, seq_len, hidden_size, epsilon, true, wmean)); + sycl::queue* qs[2] = {q, q}; + + _norm.SetMean(mean_ptr); + _norm.SetVar(var_ptr); + + if (preln) { + if (wmean) + _norm.BackwardFusedAdd(bsz_seq, + out1_grad_ptr, + out2_grad_ptr, + gamma_ptr, + gamma_grad_ptr, + betta_grad_ptr, + qs, + inp_grad_ptr, + input_ptr); + else + _norm.BackwardFusedAdd(bsz_seq, + out1_grad_ptr, + out2_grad_ptr, + gamma_ptr, + betta_ptr, + gamma_grad_ptr, + betta_grad_ptr, + qs, + inp_grad_ptr, + output_ptr); + } else { + if (wmean) + _norm.Backward(bsz_seq, + out1_grad_ptr, + gamma_ptr, + gamma_grad_ptr, + betta_grad_ptr, + qs, + inp_grad_ptr, + input_ptr); + else { + _norm.Backward(bsz_seq, + out1_grad_ptr, + gamma_ptr, + betta_ptr, + gamma_grad_ptr, + betta_grad_ptr, + qs, + inp_grad_ptr, + output_ptr); + } + } + return {input_grad, gamma_grad, betta_grad}; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_softmax_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_softmax_sycl.dp.cpp new file mode 100644 index 0000000..87bd1ce --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_softmax_sycl.dp.cpp @@ -0,0 +1,44 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/softmax.hpp" + +template +std::vector softmax_forward(int bsz, + int seq_len, + int num_heads, + torch::Tensor& inout, + const torch::Tensor& mask) +{ + CHECK_INPUT(inout); + CHECK_INPUT(mask); + + T* inout_ptr = (T*)inout.data_ptr(); + const T* mask_ptr = (const T*)mask.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Softmax _softmax = Softmax(typename Softmax::Config(bsz, num_heads, seq_len)); + _softmax.SetSeqLength(seq_len); + _softmax.Forward(bsz, inout_ptr, mask_ptr, q); + return {inout}; +} + +template +std::vector softmax_backward(int bsz, + int seq_len, + int num_heads, + torch::Tensor& out_grad, + const torch::Tensor& input) +{ + CHECK_INPUT(out_grad); + CHECK_INPUT(input); + + T* out_grad_ptr = (T*)out_grad.data_ptr(); + const T* input_ptr = (const T*)input.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + Softmax _softmax = Softmax(typename Softmax::Config(bsz, num_heads, seq_len)); + _softmax.SetSeqLength(seq_len); + + _softmax.Backward(bsz, out_grad_ptr, input_ptr, q); + return {out_grad}; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_stridedbatchgemm_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_stridedbatchgemm_sycl.dp.cpp new file mode 100644 index 0000000..f5b6102 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_stridedbatchgemm_sycl.dp.cpp @@ -0,0 +1,95 @@ +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/strided_batch_gemm.hpp" + +template +std::vector stridedbatchgemm_forward(const int batchSize, + const int m, + const int n, + const int k, + const float alpha, + const float beta, + const torch::Tensor& matA, + const torch::Tensor& matB) +{ + CHECK_INPUT(matA); + CHECK_INPUT(matB); + + auto options = torch::TensorOptions() + .dtype(matA.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + StridedBatchGemm _sbgemm = + StridedBatchGemm(typename StridedBatchGemm::Config(batchSize, + m, + n, + k, + alpha, + beta, + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + {0, 0, 0})); + + const T* matA_ptr = (const T*)matA.data_ptr(); + const T* matB_ptr = (const T*)matB.data_ptr(); + + auto matC = torch::empty({batchSize, n, m}, options); + + T* matC_ptr = (T*)matC.data_ptr(); + + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + _sbgemm.Forward(batchSize, matC_ptr, matA_ptr, matB_ptr, q); + return {matC}; +} + +template +std::vector stridedbatchgemm_backward(const int batchSize, + const int m, + const int n, + const int k, + const float alpha, + const float beta, + const torch::Tensor& grad_matC, + const torch::Tensor& matA, + const torch::Tensor& matB) +{ + CHECK_INPUT(grad_matC); + CHECK_INPUT(matA); + CHECK_INPUT(matB); + + auto options = torch::TensorOptions() + .dtype(matA.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + StridedBatchGemm _sbgemm = + StridedBatchGemm(typename StridedBatchGemm::Config(batchSize, + m, + n, + k, + alpha, + beta, + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + {0, 0, 0})); + + const T* grad_c_ptr = (const T*)grad_matC.data_ptr(); + const T* matA_ptr = (const T*)matA.data_ptr(); + const T* matB_ptr = (const T*)matB.data_ptr(); + + auto grad_matA = torch::empty(matA.sizes(), options); + auto grad_matB = torch::empty(matB.sizes(), options); + CHECK_INPUT(grad_matA); + CHECK_INPUT(grad_matB); + + T* grad_a_ptr = (T*)grad_matA.data_ptr(); + T* grad_b_ptr = (T*)grad_matB.data_ptr(); + sycl::queue* q = ::SyclContext::Instance().GetCurrentStream(); + + _sbgemm.Backward(batchSize, grad_c_ptr, matA_ptr, matB_ptr, q, grad_a_ptr, grad_b_ptr); + return {grad_matA, grad_matB}; +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_transformer_sycl.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_transformer_sycl.dp.cpp new file mode 100644 index 0000000..921b5e5 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/ds_transformer_sycl.dp.cpp @@ -0,0 +1,1081 @@ +#include "sycl/ds_transformer_sycl.hpp" +#include +#include +#include +#include +#include +#include +#include +#include "sycl/Timer.hpp" +#include "sycl/common.hpp" +#include "sycl/context.hpp" +#include "sycl/custom_sycl_layers.hpp" +#include "sycl/onednn_wrappers.hpp" +#include "sycl/onemkl_wrappers.hpp" + +static std::unordered_map> s_transformer_layers; + +const int init_seq_length = 128; + +// C++ interface + +template +size_t get_workspace_size(int maxBatchSize, + int seq_len, + int hidden_size, + int intermediate_size, + int heads, + bool training, + bool gelu_checkpoint) +{ + size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size); + if (training) { + workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size); + workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size), + 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len))); + if (gelu_checkpoint) + workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size); + } + return workSpacesize; // * sizeof(T); +} + +// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. +#define CHECK_XPU(x) AT_ASSERTM(x.is_xpu(), #x " must be a XPU tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_XPU(x); \ + CHECK_CONTIGUOUS(x) + +template +BertTransformerLayer::BertTransformerLayer(int layer_id, + int batch_size, + int hidden_size, + int num_heads, + int intermediate_size, + int seq_length, + float attn_prob_dropout_ratio, + float hidden_output_dropout_ratio, + float layer_norm_eps, + bool pre_or_postLayerNorm, + const std::vector>& gemm_algos, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint, + bool stochastic_mode) + : _layer_id(layer_id), + _batch_size(batch_size), + _hidden_size(hidden_size), + _heads(num_heads), + _intermediate_size(intermediate_size), + _seq_length(seq_length), + _training(true), + _pre_or_postLayerNorm(pre_or_postLayerNorm), + _attn_dropout_checkpoint(attn_dropout_checkpoint), + _normalize_invertible(normalize_invertible), + _gelu_checkpoint(gelu_checkpoint), + _stochastic_mode(stochastic_mode), + _stream(::SyclContext::Instance().GetCurrentStream()), + _onemklQ(::SyclContext::Instance().GetCurrentStream()), + _qkv_linear( + typename FeedForward::Config(batch_size * seq_length, 3 * hidden_size, hidden_size)), + _attn_out_linear( + typename FeedForward::Config(batch_size * seq_length, hidden_size, hidden_size)), + _attn_layer_norm(typename Normalize_Layer::Config(batch_size, + seq_length, + hidden_size, + layer_norm_eps, + true, + !normalize_invertible)), + _layer_norm(typename Normalize_Layer::Config(batch_size, + seq_length, + hidden_size, + layer_norm_eps, + true, + !normalize_invertible)), + _ff1(typename FeedForward::Config(batch_size * seq_length, + _intermediate_size, + hidden_size)), + _ff2(typename FeedForward::Config(batch_size * seq_length, + hidden_size, + _intermediate_size)), + _softmax(typename Softmax::Config(batch_size, num_heads, seq_length)), + _gelu(typename Gelu::Config(_intermediate_size)), + _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio, _seq_length)), + _attn_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, _hidden_size)), + _layer_output_dropout(typename Dropout::Config(hidden_output_dropout_ratio, _hidden_size)), + _attn_scores(typename StridedBatchGemm::Config(_batch_size * _heads, + _seq_length, + _seq_length, + _hidden_size / _heads, + //(T(1.0) / T(sqrt(_hidden_size / _heads))), + float(1.0 / sqrt(_hidden_size / _heads)), + float(0.0), + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, + gemm_algos[3])), + _attn_context(typename StridedBatchGemm::Config(_batch_size * _heads, + _hidden_size / _heads, + _seq_length, + _seq_length, + float(1.0), + float(0.0), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + gemm_algos[4])) +{ + assert(_hidden_size % _heads == 0); + Initialize(); +} + +template +BertTransformerLayer::~BertTransformerLayer() +{ +} + +template +void BertTransformerLayer::Initialize() +{ +} + +template +void BertTransformerLayer::Forward(int bsz, + const T* input_ptr, + const T* input_mask_ptr, + const T* attn_qkvw_ptr, + const T* attn_qkvb_ptr, + const T* attn_ow_ptr, + const T* attn_ob_ptr, + const T* attn_nw_ptr, + const T* attn_nb_ptr, + const T* inter_w_ptr, + const T* inter_b_ptr, + const T* output_w_ptr, + const T* output_b_ptr, + const T* norm_w_ptr, + const T* norm_b_ptr, + T* out_ptr, + T* inp_norm_ptr, + T* q_tf_ptr, + T* k_tf_ptr, + T* v_tf_ptr, + T* soft_out_ptr, + T* ctx_bufB_ptr, + T* attn_o_inp_ptr, + T* add_res_ptr, + T* ff1_inp_ptr, + T* gelu_inp_ptr, + T* ff2_inp_ptr) +{ + if (!_stochastic_mode) _stream->wait(); + + T* workspace = static_cast(::SyclContext::Instance().GetWorkSpace()); + size_t small_buf_size = bsz * _seq_length * _hidden_size; + T* buf_0 = workspace; + T* buf_1 = buf_0 + small_buf_size; + T* buf_2 = buf_1; + + if (_normalize_invertible) { + add_res_ptr = buf_1 + 3 * small_buf_size; + buf_2 = add_res_ptr; + } + if (_gelu_checkpoint) buf_2 += small_buf_size; + if (_attn_dropout_checkpoint) + ctx_bufB_ptr = + (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size) + : (buf_1 + 4 * small_buf_size)); + + int bsz_seq = bsz * _seq_length; + + if (_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) { + _layer_norm.ForwardCheckpoint( + bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + + } else { + // input_ptr[bsz_seq], norm_w_ptr[_seq_length], norm_b_ptr[_seq_length] + // --> inp_norm_ptr[bsz_seq] + _layer_norm.Forward( + bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + } + } + + if (_pre_or_postLayerNorm) { + _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _onemklQ); + } else { + _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _onemklQ); + } + + launch_bias_add_transform_0213( + q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3); + + int bsz_heads = bsz * _heads; + + // attention scores + _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _onemklQ); + + // Softmax + Mask + _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream); + + // attn prob dropout. + _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream); + + // attention context + _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _onemklQ); + + launch_transform4d_0213( + attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1); + + if (_pre_or_postLayerNorm) { + _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _onemklQ); + } else { + _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _onemklQ); + } + + // attn output dropout. + if (_pre_or_postLayerNorm) { + _attn_output_dropout.ForwardWithBias( + bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream); + } else { + _attn_output_dropout.ForwardWithBias( + bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream); + } + + if (_pre_or_postLayerNorm) { + if (_attn_layer_norm.UseMean()) { + _attn_layer_norm.ForwardCheckpoint( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + } else { + _attn_layer_norm.Forward( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + } + } else { + if (_attn_layer_norm.UseMean()) { + _attn_layer_norm.ForwardCheckpoint( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + } else { + _attn_layer_norm.Forward( + bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true); + } + } + + _ff1.Forward(bsz_seq, + ff1_inp_ptr, + inter_w_ptr, + (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), + _onemklQ); + + _gelu.ForwardWithBiasAdd(bsz_seq, + (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), + inter_b_ptr, + (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), + _stream); + + _ff2.Forward( + bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _onemklQ); + + // layer output dropout. + if (_pre_or_postLayerNorm) { + _layer_output_dropout.ForwardWithBias( + bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream); + } else { + _layer_output_dropout.ForwardWithBias( + bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream); + } + + if (!_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) { + _layer_norm.ForwardCheckpoint( + bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + } else { + _layer_norm.Forward( + bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true); + } + } +} + +template +void BertTransformerLayer::Backward(int bsz, + const T* grad_output_ptr, + const T* input_ptr, + const T* output_ptr, + const T* inp_norm_ptr, + const T* q_tf_ptr, + const T* k_tf_ptr, + const T* v_tf_ptr, + const T* soft_out_ptr, + const T* ctx_bufB_ptr, + const T* attn_o_inp_ptr, + const T* add_res_ptr, + const T* ff1_inp_ptr, + const T* gelu_inp_ptr, + const T* ff2_inp_ptr, + const T* input_mask_ptr, + const T* attn_qkvw_ptr, + const T* attn_ow_ptr, + const T* attn_nw_ptr, + const T* attn_nb_ptr, + const T* inter_w_ptr, + const T* inter_b_ptr, + const T* output_w_ptr, + const T* norm_w_ptr, + const T* norm_b_ptr, + + T* grad_input_ptr, + T* grad_attn_qkvw_ptr, + T* grad_attn_qkvb_ptr, + T* grad_attn_ow_ptr, + T* grad_attn_ob_ptr, + T* grad_attn_nw_ptr, + T* grad_attn_nb_ptr, + T* grad_inter_w_ptr, + T* grad_inter_b_ptr, + T* grad_output_w_ptr, + T* grad_output_b_ptr, + T* grad_norm_w_ptr, + T* grad_norm_b_ptr) +{ + if (!_stochastic_mode) _stream->wait(); + + T* workspace = static_cast(::SyclContext::Instance().GetWorkSpace()); + size_t small_buf_size = bsz * _seq_length * _hidden_size; + T* buf_0 = workspace; + T* buf_1 = buf_0 + small_buf_size; + T* buf_2 = buf_1 + small_buf_size; + T* buf_3 = buf_2 + small_buf_size; + + T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size) + : buf_3 + small_buf_size); + T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads); + + sycl::queue* streams[2] = {_stream, _stream}; + + int bsz_seq = bsz * _seq_length; + int bsz_heads = bsz * _heads; + + if (!_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) + _layer_norm.Backward(bsz_seq, + grad_output_ptr, + norm_w_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + buf_1, + inp_norm_ptr); + + else + _layer_norm.Backward(bsz_seq, + grad_output_ptr, + norm_w_ptr, + norm_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + buf_1, + output_ptr); + } + + if (_pre_or_postLayerNorm) + _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream); + else + _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream); + + const T* layer_dropout_buf = _layer_output_dropout.HasDropout() + ? buf_0 + : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1); + + if (_gelu_checkpoint) + _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream); + _ff2.Backward(bsz_seq, + layer_dropout_buf, + (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), + output_w_ptr, + grad_output_w_ptr, + grad_output_b_ptr, + _onemklQ, + _stream, + ff2_buf); + + _gelu.Backward( + bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream); + + _ff1.Backward(bsz_seq, + ff2_buf, + ff1_inp_ptr, + inter_w_ptr, + grad_inter_w_ptr, + grad_inter_b_ptr, + _onemklQ, + _stream, + buf_3); + + if (!_pre_or_postLayerNorm) + launch_fused_add2(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream); + + if (_pre_or_postLayerNorm) { + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.BackwardFusedAdd(bsz_seq, + buf_3, + grad_output_ptr, + attn_nw_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + add_res_ptr); + + else + _attn_layer_norm.BackwardFusedAdd(bsz_seq, + buf_3, + grad_output_ptr, + attn_nw_ptr, + attn_nb_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + ff1_inp_ptr); + } else { + if (_attn_layer_norm.UseMean()) + _attn_layer_norm.Backward(bsz_seq, + buf_2, + attn_nw_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + add_res_ptr); + + else + _attn_layer_norm.Backward(bsz_seq, + buf_2, + attn_nw_ptr, + attn_nb_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + streams, + buf_0, + ff1_inp_ptr); + } + + _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream); + + T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0; + + _attn_out_linear.Backward(bsz_seq, + attn_output_dropout_buf, + attn_o_inp_ptr, + attn_ow_ptr, + grad_attn_ow_ptr, + grad_attn_ob_ptr, + _onemklQ, + _stream, + buf_1); + + launch_transform_0213(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream); + + if (_attn_prob_dropout.HasDropout()) { + if (_attn_dropout_checkpoint) { + _attn_prob_dropout.Forward( + bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true); + } + + _attn_context.Backward(bsz_heads, + buf_2, + v_tf_ptr, + (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr), + _onemklQ, + buf_3, + ff2_buf); + } else { + _attn_context.Backward(bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _onemklQ, buf_3, ff2_buf); + } + + _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream); + + _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream); + + _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _onemklQ, buf_2, buf_1); + + // the size of input (buf_1) relates to the last argument (trans_count), in + // this case, buf_1 connected with buf_2 and buf_3 are all inputs + launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3); + + T* grad_out_buffer = (T*)malloc_shared(10 * sizeof(T), *_stream); + T* input_buffer = (T*)malloc_shared(10 * sizeof(T), *_stream); + T* weight_buffer = (T*)malloc_shared(10 * sizeof(T), *_stream); + T* grad_weight_buffer = (T*)malloc_shared(10 * sizeof(T), *_stream); + T* grad_bias_buffer = (T*)malloc_shared(10 * sizeof(T), *_stream); + if (_pre_or_postLayerNorm) { + _qkv_linear.Backward(bsz_seq, + ff2_buf, + inp_norm_ptr, + attn_qkvw_ptr, + grad_attn_qkvw_ptr, + grad_attn_qkvb_ptr, + _onemklQ, + _stream, + buf_2); + } else { + _qkv_linear.Backward(bsz_seq, + ff2_buf, + input_ptr, + attn_qkvw_ptr, + grad_attn_qkvw_ptr, + grad_attn_qkvb_ptr, + _onemklQ, + _stream, + buf_2); + } + + if (_pre_or_postLayerNorm) { + if (_layer_norm.UseMean()) { + _layer_norm.BackwardFusedAdd(bsz_seq, + buf_2, + buf_0, + norm_w_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + grad_input_ptr, + input_ptr); + } + + else { + _layer_norm.BackwardFusedAdd(bsz_seq, + buf_2, + buf_0, + norm_w_ptr, + norm_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr, + streams, + grad_input_ptr, + inp_norm_ptr); + } + } else { + launch_fused_add2(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream); + _stream->submit([&](sycl::handler& cgh) { + cgh.single_task([=]() { + for (int i = 0; i < 10; ++i) { + grad_out_buffer[i] = ff2_buf[i]; + input_buffer[i] = inp_norm_ptr[i]; + weight_buffer[i] = attn_qkvw_ptr[i]; + grad_weight_buffer[i] = grad_attn_qkvw_ptr[i]; + grad_bias_buffer[i] = grad_attn_qkvb_ptr[i]; + } + }); + }); + } + + _stream->wait(); +} + +template +void BertTransformerLayer::SetTrainingMode(bool training) +{ + // Dropout will be skipped when not in training model. + _attn_prob_dropout.SetTrainingMode(training); + _attn_output_dropout.SetTrainingMode(training); + _layer_output_dropout.SetTrainingMode(training); +} + +template +void BertTransformerLayer::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr, + uint8_t* attn_output_dropout_mask_ptr, + uint8_t* layer_output_dropout_mask_ptr, + T* attn_layer_norm_var, + T* attn_layer_norm_mean, + T* layer_norm_var, + T* layer_norm_mean) +{ + _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr); + _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr); + _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr); + + _attn_layer_norm.SetVar(attn_layer_norm_var); + _attn_layer_norm.SetMean(attn_layer_norm_mean); + _layer_norm.SetVar(layer_norm_var); + _layer_norm.SetMean(layer_norm_mean); +} + +template +void BertTransformerLayer::SetSeqLength(int seq_len) +{ + _seq_length = seq_len; + + _softmax.SetSeqLength(_seq_length); + _attn_prob_dropout.SetDimension(_seq_length); + _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads); + _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length); +} + +template +int create_transformer_layer(int layer_id, + int batch_size, + int hidden_dim, + int num_heads, + int intermediate_size, + float attn_dropout_ratio, + float hidden_dropout_ratio, + float layer_norm_eps, + int seed, + bool pre_or_postLayerNorm, + bool test_gemm, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint, + bool stochastic_mode) +{ + ::SyclContext::Instance().SetSeed(seed); + ::SyclContext::Instance().TestGemmFP16( + test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads); + auto layer = std::make_shared>(layer_id, + batch_size, + hidden_dim, + num_heads, + intermediate_size, + init_seq_length, + attn_dropout_ratio, + hidden_dropout_ratio, + layer_norm_eps, + pre_or_postLayerNorm, + ::SyclContext::Instance().GetGemmAlgos(), + attn_dropout_checkpoint, + normalize_invertible, + gelu_checkpoint, + stochastic_mode); + s_transformer_layers[layer_id] = layer; + + std::string dtype = (std::is_same::value) + ? "half" + : ((std::is_same::value) ? "bf16" : "float"); + + std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]." + << std::endl; + + return 0; +} + +template +std::vector ds_transformer_forward(int layer_id, + const torch::Tensor& input, + const torch::Tensor& input_mask, + const torch::Tensor& attn_qkvw, + const torch::Tensor& attn_qkvb, + const torch::Tensor& attn_ow, + const torch::Tensor& attn_ob, + const torch::Tensor& attn_nw, + const torch::Tensor& attn_nb, + const torch::Tensor& inter_w, + const torch::Tensor& inter_b, + const torch::Tensor& output_w, + const torch::Tensor& output_b, + const torch::Tensor& norm_w, + const torch::Tensor& norm_b, + bool training_mode, + bool prelayernorm, + bool attn_dropout_checkpoint, + bool normalize_invertible, + bool gelu_checkpoint) +{ + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + CHECK_INPUT(attn_qkvw); + CHECK_INPUT(attn_qkvb); + CHECK_INPUT(attn_ow); + CHECK_INPUT(attn_ob); + CHECK_INPUT(attn_nw); + CHECK_INPUT(attn_nb); + CHECK_INPUT(inter_w); + CHECK_INPUT(inter_b); + CHECK_INPUT(output_w); + CHECK_INPUT(output_b); + CHECK_INPUT(norm_w); + CHECK_INPUT(norm_b); + + int bsz = input.size(0); + + const T* input_ptr = (const T*)input.data_ptr(); + const T* input_mask_ptr = (const T*)input_mask.data_ptr(); + const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr(); + const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr(); + const T* attn_ow_ptr = (const T*)attn_ow.data_ptr(); + const T* attn_ob_ptr = (const T*)attn_ob.data_ptr(); + const T* attn_nw_ptr = (const T*)attn_nw.data_ptr(); + const T* attn_nb_ptr = (const T*)attn_nb.data_ptr(); + const T* inter_w_ptr = (const T*)inter_w.data_ptr(); + const T* inter_b_ptr = (const T*)inter_b.data_ptr(); + const T* output_w_ptr = (const T*)output_w.data_ptr(); + const T* output_b_ptr = (const T*)output_b.data_ptr(); + const T* norm_w_ptr = (const T*)norm_w.data_ptr(); + const T* norm_b_ptr = (const T*)norm_b.data_ptr(); + + auto output = torch::empty_like(input); + T* out_ptr = (T*)output.data_ptr(); + + auto options = torch::TensorOptions() + .dtype(input.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + + auto uint8_options = torch::TensorOptions() + .dtype(torch::kInt8) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(false); + + std::shared_ptr> layer = + std::static_pointer_cast>(s_transformer_layers[layer_id]); + + int seq_len = layer->GetSeqLength(); + if (input.size(1) != seq_len) { + seq_len = input.size(1); + layer->SetSeqLength(seq_len); + } + + auto workspace = torch::empty({(long)get_workspace_size(bsz, + seq_len, + layer->GetHiddenSize(), + layer->GetIntermediateSize(), + layer->GetNumHeads(), + layer->IsTrainingMode(), + layer->GeluCheckpoint())}, + options); + ::SyclContext::Instance().SetWorkSpace((T*)workspace.data_ptr()); + + auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output); + auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input)); + auto attn_o_inp = torch::empty_like(input); + auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options); + + auto attn_prob_dropout_mask = + torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options); + auto attn_output_dropout_mask = + torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options); + auto layer_output_dropout_mask = + torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options); + + auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options); + auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options); + auto layer_norm_var = torch::empty({(bsz * seq_len)}, options); + auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options); + + T* inp_norm_ptr = (T*)inp_norm.data_ptr(); + T* add_res_ptr = (T*)add_res.data_ptr(); + T* q_tf_ptr = (T*)qkv_tf.data_ptr(); + T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr(); + T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr(); + T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr(); + + torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options); + torch::Tensor gelu_inp = + (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options)); + auto ff1_inp = torch::empty_like(input); + T* ff2_inp_ptr = (T*)ff2_inp.data_ptr(); + T* gelu_inp_ptr = (T*)gelu_inp.data_ptr(); + T* ff1_inp_ptr = (T*)ff1_inp.data_ptr(); + + torch::Tensor soft_out = + torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options); + torch::Tensor ctx_bufB = + (attn_dropout_checkpoint + ? soft_out + : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options)); + T* soft_out_ptr = (T*)soft_out.data_ptr(); + T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr(); + + layer->SetTrainingMode(training_mode); + layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), + (uint8_t*)attn_output_dropout_mask.data_ptr(), + (uint8_t*)layer_output_dropout_mask.data_ptr(), + (T*)attn_layer_norm_var.data_ptr(), + (T*)attn_layer_norm_mean.data_ptr(), + (T*)layer_norm_var.data_ptr(), + (T*)layer_norm_mean.data_ptr()); + + layer->Forward(bsz, + input_ptr, + input_mask_ptr, + attn_qkvw_ptr, + attn_qkvb_ptr, + attn_ow_ptr, + attn_ob_ptr, + attn_nw_ptr, + attn_nb_ptr, + inter_w_ptr, + inter_b_ptr, + output_w_ptr, + output_b_ptr, + norm_w_ptr, + norm_b_ptr, + out_ptr, + inp_norm_ptr, + q_tf_ptr, + k_tf_ptr, + v_tf_ptr, + soft_out_ptr, + ctx_bufB_ptr, + attn_o_inp_ptr, + add_res_ptr, + ff1_inp_ptr, + gelu_inp_ptr, + ff2_inp_ptr); + + return {output, + inp_norm, + qkv_tf, + soft_out, + ctx_bufB, + attn_o_inp, + add_res, + ff1_inp, + gelu_inp, + ff2_inp, + attn_prob_dropout_mask, + attn_output_dropout_mask, + layer_output_dropout_mask, + attn_layer_norm_var, + attn_layer_norm_mean, + layer_norm_var, + layer_norm_mean}; +} + +template +std::vector ds_transformer_backward(int layer_id, + const torch::Tensor& grad_output, + const torch::Tensor& output, + const torch::Tensor& inp_norm, + const torch::Tensor& qkv_tf, + const torch::Tensor& soft_out, + const torch::Tensor& ctx_bufB, + const torch::Tensor& attn_o_inp, + const torch::Tensor& add_res, + const torch::Tensor& ff1_inp, + const torch::Tensor& gelu_inp, + const torch::Tensor& ff2_inp, + const torch::Tensor& attn_prob_dropout_mask, + const torch::Tensor& attn_output_dropout_mask, + const torch::Tensor& layer_output_dropout_mask, + const torch::Tensor& attn_layer_norm_var, + const torch::Tensor& attn_layer_norm_mean, + const torch::Tensor& layer_norm_var, + const torch::Tensor& layer_norm_mean, + const torch::Tensor& input, + const torch::Tensor& input_mask, + const torch::Tensor& attn_qkvw, + const torch::Tensor& attn_qkvb, + const torch::Tensor& attn_ow, + const torch::Tensor& attn_ob, + const torch::Tensor& attn_nw, + const torch::Tensor& attn_nb, + const torch::Tensor& inter_w, + const torch::Tensor& inter_b, + const torch::Tensor& output_w, + const torch::Tensor& output_b, + const torch::Tensor& norm_w, + const torch::Tensor& norm_b) +{ + auto g_output = grad_output.contiguous(); + CHECK_INPUT(g_output); + CHECK_INPUT(output); + CHECK_INPUT(inp_norm); + CHECK_INPUT(qkv_tf); + CHECK_INPUT(add_res); + CHECK_INPUT(soft_out); + CHECK_INPUT(ctx_bufB); + CHECK_INPUT(attn_o_inp); + CHECK_INPUT(ff1_inp); + CHECK_INPUT(gelu_inp); + CHECK_INPUT(ff2_inp); + CHECK_INPUT(input); + CHECK_INPUT(input_mask); + CHECK_INPUT(attn_qkvw); + CHECK_INPUT(attn_qkvb); + CHECK_INPUT(attn_ow); + CHECK_INPUT(attn_ob); + CHECK_INPUT(attn_nw); + CHECK_INPUT(attn_nb); + CHECK_INPUT(inter_w); + CHECK_INPUT(inter_b); + CHECK_INPUT(output_w); + CHECK_INPUT(output_b); + CHECK_INPUT(norm_w); + CHECK_INPUT(norm_b); + + int bsz = g_output.size(0); + + std::shared_ptr> layer = + std::static_pointer_cast>(s_transformer_layers[layer_id]); + + int seq_len = layer->GetSeqLength(); + if (g_output.size(1) != seq_len) { + seq_len = g_output.size(1); + layer->SetSeqLength(seq_len); + } + auto options = torch::TensorOptions() + .dtype(g_output.options().dtype()) + .layout(torch::kStrided) + .device(torch::kXPU) + .requires_grad(true); + auto workspace = torch::empty({(long)get_workspace_size(bsz, + seq_len, + layer->GetHiddenSize(), + layer->GetIntermediateSize(), + layer->GetNumHeads(), + layer->IsTrainingMode(), + layer->GeluCheckpoint())}, + options); + ::SyclContext::Instance().SetWorkSpace((T*)workspace.data_ptr()); + + auto grad_input = torch::empty_like(input); + auto grad_attn_qkvw = torch::empty_like(attn_qkvw); + auto grad_attn_qkvb = torch::empty_like(attn_qkvb); + auto grad_attn_ow = torch::empty_like(attn_ow); + auto grad_attn_ob = torch::empty_like(attn_ob); + auto grad_attn_nw = torch::empty_like(attn_nw); + auto grad_attn_nb = torch::empty_like(attn_nb); + auto grad_inter_w = torch::empty_like(inter_w); + auto grad_inter_b = torch::empty_like(inter_b); + auto grad_output_w = torch::empty_like(output_w); + auto grad_output_b = torch::empty_like(output_b); + auto grad_norm_w = torch::empty_like(norm_w); + auto grad_norm_b = torch::empty_like(norm_b); + + // inputs. + const T* grad_output_ptr = (const T*)g_output.data_ptr(); + const T* input_ptr = (const T*)input.data_ptr(); + const T* output_ptr = (const T*)output.data_ptr(); + const T* inp_norm_ptr = (const T*)inp_norm.data_ptr(); + const T* q_tf_ptr = (const T*)qkv_tf.data_ptr(); + const T* add_res_ptr = (const T*)add_res.data_ptr(); + const T* k_tf_ptr = + q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)k_tf.data_ptr(); + const T* v_tf_ptr = + k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)v_tf.data_ptr(); + const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr(); + const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr(); + const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr(); + const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr(); + const T* soft_out_ptr = (const T*)soft_out.data_ptr(); + const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr(); + const T* input_mask_ptr = (const T*)input_mask.data_ptr(); + const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr(); + const T* attn_ow_ptr = (const T*)attn_ow.data_ptr(); + const T* attn_nw_ptr = (const T*)attn_nw.data_ptr(); + const T* attn_nb_ptr = (const T*)attn_nb.data_ptr(); + const T* inter_w_ptr = (const T*)inter_w.data_ptr(); + const T* inter_b_ptr = (const T*)inter_b.data_ptr(); + const T* output_w_ptr = (const T*)output_w.data_ptr(); + const T* norm_w_ptr = (const T*)norm_w.data_ptr(); + const T* norm_b_ptr = (const T*)norm_b.data_ptr(); + + // outputs. + T* grad_input_ptr = (T*)grad_input.data_ptr(); + T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr(); + T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr(); + T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr(); + T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr(); + T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr(); + T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr(); + T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr(); + T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr(); + T* grad_output_w_ptr = (T*)grad_output_w.data_ptr(); + T* grad_output_b_ptr = (T*)grad_output_b.data_ptr(); + T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr(); + T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr(); + + layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(), + (uint8_t*)attn_output_dropout_mask.data_ptr(), + (uint8_t*)layer_output_dropout_mask.data_ptr(), + (T*)attn_layer_norm_var.data_ptr(), + (T*)attn_layer_norm_mean.data_ptr(), + (T*)layer_norm_var.data_ptr(), + (T*)layer_norm_mean.data_ptr()); + + layer->Backward(bsz, + grad_output_ptr, + input_ptr, + output_ptr, + inp_norm_ptr, + q_tf_ptr, + k_tf_ptr, + v_tf_ptr, + soft_out_ptr, + ctx_bufB_ptr, + attn_o_inp_ptr, + add_res_ptr, + ff1_inp_ptr, + gelu_inp_ptr, + ff2_inp_ptr, + input_mask_ptr, + attn_qkvw_ptr, + attn_ow_ptr, + attn_nw_ptr, + attn_nb_ptr, + inter_w_ptr, + inter_b_ptr, + output_w_ptr, + norm_w_ptr, + norm_b_ptr, + + grad_input_ptr, + grad_attn_qkvw_ptr, + grad_attn_qkvb_ptr, + grad_attn_ow_ptr, + grad_attn_ob_ptr, + grad_attn_nw_ptr, + grad_attn_nb_ptr, + grad_inter_w_ptr, + grad_inter_b_ptr, + grad_output_w_ptr, + grad_output_b_ptr, + grad_norm_w_ptr, + grad_norm_b_ptr); + + return {grad_input, + grad_attn_qkvw, + grad_attn_qkvb, + grad_attn_ow, + grad_attn_ob, + grad_attn_nw, + grad_attn_nb, + grad_inter_w, + grad_inter_b, + grad_output_w, + grad_output_b, + grad_norm_w, + grad_norm_b}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("forward_fp32", + &ds_transformer_forward, + "DeepSpeed Transformer forward with fp32 (DPCPP)"); + m.def("forward_fp16", + &ds_transformer_forward, + "DeepSpeed Transformer forward with fp16 (DPCPP)"); + m.def("forward_bf16", + &ds_transformer_forward, + "DeepSpeed Transformer forward with bf16 (DPCPP)"); + m.def("backward_fp32", + &ds_transformer_backward, + "DeepSpeed Transformer backward with fp32 (DPCPP)"); + m.def("backward_fp16", + &ds_transformer_backward, + "DeepSpeed Transformer backward with fp16 (DPCPP)"); + m.def("backward_bf16", + &ds_transformer_backward, + "DeepSpeed Transformer backward with bf16 (DPCPP)"); + m.def("create_transformer_layer_fp32", + &create_transformer_layer, + "Create DeepSpeed Transformer Transformer Layer with fp32 (DPCPP)"); + m.def("create_transformer_layer_fp16", + &create_transformer_layer, + "Create DeepSpeed Transformer Transformer Layer with fp16 (DPCPP)"); + m.def("create_transformer_layer_bf16", + &create_transformer_layer, + "Create DeepSpeed Transformer Transformer Layer with bf16 (DPCPP)"); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/gelu_kernels.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/gelu_kernels.dp.cpp new file mode 100644 index 0000000..ebcd566 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/gelu_kernels.dp.cpp @@ -0,0 +1,447 @@ +#include +using namespace cl::sycl; +#include + +using bf16 = sycl::ext::oneapi::experimental::bfloat16; + +inline float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanh(sqrt_param * (x + mul_param * x * x * x))); +} + +inline float d_gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanh(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return (dg1 + dg2 + dg3); +} + +/* + Fused bias add with GELU + + Loads a vector of 4 elements each iteration, for stride + iterations. It was written with the intention to launch 256 thread + threadblocks, so to launch for bert-large, we would set ITERATIONS + to 4. This is currently done automatically as a heuristic, setting + the number of iterations as blocks of 1024. + + For FP16, the values are loaded from memory as half, but converted + to FP32 for the arithmetic itself, to prevent numerous overflow on + the intermediate hyperbolic tangent, since there's no intrinsic + that computes it directly. +*/ + +void gelu_kernel(const float* input, + float* vals, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + + data.x() = gelu(data.x()); + data.y() = gelu(data.y()); + data.z() = gelu(data.z()); + data.w() = gelu(data.w()); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +void gelu_kernel(const half* input, half* vals, int row_stride, int iterations, nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + + half2* vals_half = reinterpret_cast(&vals_vec); + + float2 low_data = vals_half[0].convert(); // __half22float2(vals_half[0]); + float2 high_data = vals_half[1].convert(); // __half22float2(vals_half[1]); + + low_data.x() = gelu(low_data.x()); + low_data.y() = gelu(low_data.y()); + high_data.x() = gelu(high_data.x()); + high_data.y() = gelu(high_data.y()); + + vals_half[0] = low_data.convert(); // __float22half2_rn(low_data); + vals_half[1] = high_data.convert(); // __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +} + +void fused_bias_gelu(const float* input, + const float* bias, + float* vals, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + data.x() += bias_data.x(); + data.y() += bias_data.y(); + data.z() += bias_data.z(); + data.w() += bias_data.w(); + + data.x() = gelu(data.x()); + data.y() = gelu(data.y()); + data.z() = gelu(data.z()); + data.w() = gelu(data.w()); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +void fused_bias_gelu(const bf16* input, + const bf16* bias, + bf16* vals, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + const ushort4* input_cast = reinterpret_cast(input); + ushort4* vals_cast = reinterpret_cast(vals); + const ushort4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + ushort4 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + ushort4 bias_vec = bias_cast[i * loop_stride + id]; + + float4 data = {bf16::to_float(vals_vec.x()), + bf16::to_float(vals_vec.y()), + bf16::to_float(vals_vec.z()), + bf16::to_float(vals_vec.w())}; + float4 bias = {bf16::to_float(bias_vec.x()), + bf16::to_float(bias_vec.y()), + bf16::to_float(bias_vec.z()), + bf16::to_float(bias_vec.w())}; + + data += bias; + + data.x() = gelu(data.x()); + data.y() = gelu(data.y()); + data.z() = gelu(data.z()); + data.w() = gelu(data.w()); + + vals_vec.x() = bf16::from_float(data.x()); + vals_vec.y() = bf16::from_float(data.y()); + vals_vec.z() = bf16::from_float(data.z()); + vals_vec.w() = bf16::from_float(data.w()); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +} + +void fused_bias_gelu(const half* input, + const half* bias, + half* vals, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + const float2* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + half2* vals_half = reinterpret_cast(&vals_vec); + half2* bias_half = reinterpret_cast(&bias_vec); + + float2 low_data = vals_half[0].convert(); // __half22float2(vals_half[0]); + float2 high_data = vals_half[1].convert(); // __half22float2(vals_half[1]); + + float2 low_bias = bias_half[0].convert(); // __half22float2(bias_half[0]); + float2 high_bias = bias_half[1].convert(); // __half22float2(bias_half[1]); + + low_data.x() += low_bias.x(); + low_data.y() += low_bias.y(); + high_data.x() += high_bias.x(); + high_data.y() += high_bias.y(); + + low_data.x() = gelu(low_data.x()); + low_data.y() = gelu(low_data.y()); + high_data.x() = gelu(high_data.x()); + high_data.y() = gelu(high_data.y()); + + vals_half[0] = low_data.convert(); // __float22half2_rn(low_data); + vals_half[1] = high_data.convert(); // __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +} + +void d_gelu_func(float* d_output, + const float* gelu_input, + const float* bias, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + float4* d_output_cast = reinterpret_cast(d_output); + const float4* gelu_input_cast = reinterpret_cast(gelu_input); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + gelu_input_data.x() += bias_data.x(); + gelu_input_data.y() += bias_data.y(); + gelu_input_data.z() += bias_data.z(); + gelu_input_data.w() += bias_data.w(); + + output_data.x() *= d_gelu(gelu_input_data.x()); + output_data.y() *= d_gelu(gelu_input_data.y()); + output_data.z() *= d_gelu(gelu_input_data.z()); + output_data.w() *= d_gelu(gelu_input_data.w()); + + d_output_cast[row * row_stride + i * loop_stride + id] = output_data; + } + } +} + +void d_gelu_func(bf16* d_output, + const bf16* gelu_input, + const bf16* bias, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + ushort4* d_output_cast = reinterpret_cast(d_output); + const ushort4* gelu_input_cast = reinterpret_cast(gelu_input); + const ushort4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + ushort4 output_vec = d_output_cast[row * row_stride + i * loop_stride + id]; + ushort4 gelu_input_vec = gelu_input_cast[row * row_stride + i * loop_stride + id]; + ushort4 bias_vec = bias_cast[i * loop_stride + id]; + + float4 gelu_input_data = {bf16::to_float(gelu_input_vec.x()), + bf16::to_float(gelu_input_vec.y()), + bf16::to_float(gelu_input_vec.z()), + bf16::to_float(gelu_input_vec.w())}; + float4 bias_data = { + bf16::to_float(bias_vec.x()), + bf16::to_float(bias_vec.y()), + bf16::to_float(bias_vec.z()), + bf16::to_float(bias_vec.w()), + }; + float4 output_data = { + bf16::to_float(output_vec.x()), + bf16::to_float(output_vec.y()), + bf16::to_float(output_vec.z()), + bf16::to_float(output_vec.w()), + }; + + gelu_input_data.x() += bias_data.x(); + gelu_input_data.y() += bias_data.y(); + gelu_input_data.z() += bias_data.z(); + gelu_input_data.w() += bias_data.w(); + + output_data.x() *= d_gelu(gelu_input_data.x()); + output_data.y() *= d_gelu(gelu_input_data.y()); + output_data.z() *= d_gelu(gelu_input_data.z()); + output_data.w() *= d_gelu(gelu_input_data.w()); + + output_vec.x() = bf16::from_float(output_data.x()); + output_vec.y() = bf16::from_float(output_data.y()); + output_vec.z() = bf16::from_float(output_data.z()); + output_vec.w() = bf16::from_float(output_data.w()); + d_output_cast[row * row_stride + i * loop_stride + id] = output_vec; + } + } +} + +void d_gelu_func(half* d_output, + const half* gelu_input, + const half* bias, + int row_stride, + int iterations, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int loop_stride = item_ct1.get_local_range(2); + + float2* d_output_cast = reinterpret_cast(d_output); + const float2* gelu_input_cast = reinterpret_cast(gelu_input); + const float2* bias_cast = reinterpret_cast(bias); + +#pragma unroll + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + half2* output_data_half = reinterpret_cast(&output_data); + half2* gelu_input_data_half = reinterpret_cast(&gelu_input_data); + half2* bias_half = reinterpret_cast(&bias_vec); + + float2 output_half_0 = + output_data_half[0].convert(); // __half22float2(output_data_half[0]); + float2 output_half_1 = + output_data_half[1].convert(); // __half22float2(output_data_half[1]); + + float2 gelu_input_half_0 = + gelu_input_data_half[0] + .convert(); // __half22float2(gelu_input_data_half[0]); + float2 gelu_input_half_1 = + gelu_input_data_half[1] + .convert(); // __half22float2(gelu_input_data_half[1]); + + float2 bias_half_0 = bias_half[0].convert(); // __half22float2(bias_half[0]); + float2 bias_half_1 = bias_half[1].convert(); // __half22float2(bias_half[1]); + + gelu_input_half_0.x() += bias_half_0.x(); + gelu_input_half_0.y() += bias_half_0.y(); + gelu_input_half_1.x() += bias_half_1.x(); + gelu_input_half_1.y() += bias_half_1.y(); + + output_half_0.x() *= d_gelu(gelu_input_half_0.x()); + output_half_0.y() *= d_gelu(gelu_input_half_0.y()); + output_half_1.x() *= d_gelu(gelu_input_half_1.x()); + output_half_1.y() *= d_gelu(gelu_input_half_1.y()); + + float2 result; + half2* result_half2 = reinterpret_cast(&result); + + result_half2[0] = output_half_0.convert(); // __float22half2_rn(output_half_0); + result_half2[1] = output_half_1.convert(); // __float22half2_rn(output_half_1); + + d_output_cast[row * row_stride + i * loop_stride + id] = result; + } + } +} + +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + queue* stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + range<3> block_dims(1, 1, threads); + range<3> grid_dims(1, 1, batch_size); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dims * block_dims, block_dims), [=](nd_item<3> item_ct1) { + fused_bias_gelu(input, bias, output, intermediate_size / 4, iterations, item_ct1); + }); + }); +} + +template +void launch_gelu(const T* input, T* output, int intermediate_size, int batch_size, queue* stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + range<3> block_dims(1, 1, threads); + range<3> grid_dims(1, 1, batch_size); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dims * block_dims, block_dims), [=](nd_item<3> item_ct1) { + gelu_kernel(input, output, intermediate_size / 4, iterations, item_ct1); + }); + }); +} + +template void launch_bias_gelu(const float*, const float*, float*, int, int, queue*); +template void launch_bias_gelu(const half*, const half*, half*, int, int, queue*); +template void launch_bias_gelu(const bf16*, const bf16*, bf16*, int, int, queue*); + +template void launch_gelu(const float*, float*, int, int, queue*); +template void launch_gelu(const half*, half*, int, int, queue*); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + queue* stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + range<3> block_dims(1, 1, threads); + range<3> grid_dims(1, 1, batch_size); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dims * block_dims, block_dims), [=](nd_item<3> item_ct1) { + d_gelu_func(d_output, input, bias, intermediate_size / 4, iterations, item_ct1); + }); + }); +} + +template void launch_d_gelu(float*, const float*, const float*, int, int, queue*); +template void launch_d_gelu(half*, const half*, const half*, int, int, queue*); +template void launch_d_gelu(bf16*, const bf16*, const bf16*, int, int, queue*); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/general_kernels.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/general_kernels.dp.cpp new file mode 100644 index 0000000..4d91e50 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/general_kernels.dp.cpp @@ -0,0 +1,540 @@ +#include "sycl/general_kernels.hpp" +#include + +using namespace cl::sycl; + +constexpr int MAX_SG_NUM = 32; +constexpr int MAX_SG_NUM1 = MAX_SG_NUM + 1; +template +void column_sum_reduce(const T* inp, T* out, int rows, int width, nd_item<3> item_ct1, float* tile) +{ + group<3> b = item_ct1.get_group(); + sub_group sg = item_ct1.get_sub_group(); + + int idx = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + int y_stride = width * MAX_SG_NUM; + + float localSum = 0; + + // Loop across matrix height + if (idx < width) { + int offset = item_ct1.get_local_id(1) * width + idx; + for (int r = item_ct1.get_local_id(1); r < rows; r += MAX_SG_NUM) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + tile[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = localSum; + + item_ct1.barrier(); + + // Sum the shared buffer. + float sum = tile[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + for (int i = 1; i < MAX_SG_NUM; i <<= 1) { sum += sg.shuffle_down(sum, i); } + + if (item_ct1.get_local_id(2) == 0) { + int pos = item_ct1.get_group(2) * MAX_SG_NUM + item_ct1.get_local_id(1); + if (pos < width) out[pos] = sum; + } +} + +template <> +void column_sum_reduce(const bf16* inp, + bf16* out, + int rows, + int width, + nd_item<3> item_ct1, + float* tile) +{ + group<3> b = item_ct1.get_group(); + sub_group sg = item_ct1.get_sub_group(); + + int idx = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + + int y_stride = width * MAX_SG_NUM; + + float localSum = 0; + ushort* inp_cast = (ushort*)inp; + ushort* out_cast = (ushort*)out; + // Loop across matrix height + if (idx < width) { + int offset = item_ct1.get_local_id(1) * width + idx; + for (int r = item_ct1.get_local_id(1); r < rows; r += MAX_SG_NUM) { + localSum += bf16::to_float(inp_cast[offset]); + offset += y_stride; + } + } + + tile[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = localSum; + + item_ct1.barrier(); + + // Sum the shared buffer. + float sum = tile[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + for (int i = 1; i < MAX_SG_NUM; i <<= 1) { sum += sg.shuffle_down(sum, i); } + + if (item_ct1.get_local_id(2) == 0) { + int pos = item_ct1.get_group(2) * MAX_SG_NUM + item_ct1.get_local_id(1); + if (pos < width) out_cast[pos] = bf16::from_float(sum); + } +} + +template +void launch_fuse_transpose_bias_kernel(const T* inp, T* out, int rows, int cols, queue* stream) +{ + range<3> grid_dim(1, 1, (cols - 1) / MAX_SG_NUM + 1); + range<3> block_dim(1, MAX_SG_NUM, MAX_SG_NUM); + + stream->submit([&](handler& cgh) { + accessor tile( + range<2>(MAX_SG_NUM, MAX_SG_NUM1), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + column_sum_reduce( + inp, out, rows, cols, item_ct1, tile.get_pointer()); + }); + }); +} + +template void launch_fuse_transpose_bias_kernel(const float* inp, + float* out, + int rows, + int cols, + queue* stream); +template void launch_fuse_transpose_bias_kernel(const bf16* inp, + bf16* out, + int rows, + int cols, + queue* stream); +template void launch_fuse_transpose_bias_kernel(const half* inp, + half* out, + int rows, + int cols, + queue* stream); + +void fused_add2_kernel(const int N, + float* out, + const float* inp1, + const float* inp2, + nd_item<3> item_ct1) +{ + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + float4* out_4 = reinterpret_cast(out); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 val; + float4 inp1_reg = inp1_4[j]; + float4 inp2_reg = inp2_4[j]; + + val.x() = inp1_reg.x() + inp2_reg.x(); + val.y() = inp1_reg.y() + inp2_reg.y(); + val.z() = inp1_reg.z() + inp2_reg.z(); + val.w() = inp1_reg.w() + inp2_reg.w(); + + out_4[j] = val; + } +} + +void fused_add2_kernel(const int N, + bf16* out, + const bf16* inp1, + const bf16* inp2, + nd_item<3> item_ct1) +{ + const ushort4* inp1_cast = reinterpret_cast(inp1); + const ushort4* inp2_cast = reinterpret_cast(inp2); + ushort4* out_cast = reinterpret_cast(out); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + float4 val; + float4 inp1_reg = {bf16::to_float(inp1_cast[j].x()), + bf16::to_float(inp1_cast[j].y()), + bf16::to_float(inp1_cast[j].z()), + bf16::to_float(inp1_cast[j].w())}; + float4 inp2_reg = {bf16::to_float(inp2_cast[j].x()), + bf16::to_float(inp2_cast[j].y()), + bf16::to_float(inp2_cast[j].z()), + bf16::to_float(inp2_cast[j].w())}; + + val.x() = inp1_reg.x() + inp2_reg.x(); + val.y() = inp1_reg.y() + inp2_reg.y(); + val.z() = inp1_reg.z() + inp2_reg.z(); + val.w() = inp1_reg.w() + inp2_reg.w(); + + out_cast[j] = {bf16::from_float(val.x()), + bf16::from_float(val.y()), + bf16::from_float(val.z()), + bf16::from_float(val.w())}; + } +} + +void fused_add2_kernel(const int N, + half* out, + const half* inp1, + const half* inp2, + nd_item<3> item_ct1) +{ + float2 inp1_4; + float2 inp2_4; + + half2* inp1_h = reinterpret_cast(&inp1_4); + half2* inp2_h = reinterpret_cast(&inp2_4); + + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + + DPCPP_1D_KERNEL_LOOP(j, N) + { + inp1_4 = inp1_arr[j]; + inp2_4 = inp2_arr[j]; + + float2 inp1_h_f_0 = inp1_h[0].convert(); + float2 inp1_h_f_1 = inp1_h[1].convert(); + + float2 inp2_h_f_0 = inp2_h[0].convert(); + float2 inp2_h_f_1 = inp2_h[1].convert(); + + inp1_h_f_0.x() += inp2_h_f_0.x(); + inp1_h_f_0.y() += inp2_h_f_0.y(); + inp1_h_f_1.x() += inp2_h_f_1.x(); + inp1_h_f_1.y() += inp2_h_f_1.y(); + + float2 val_f; + half2* val_h = reinterpret_cast(&val_f); + + val_h[0] = inp1_h_f_0.convert(); + val_h[1] = inp1_h_f_1.convert(); + + float2* out_4 = reinterpret_cast(out); + out_4[j] = val_f; + } +} + +template +void launch_fused_add2(T* out, + const T* inp1, + const T* inp2, + int batch_size, + int seq_length, + int hidden_dim, + queue* stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + range<3> grid_dim = range<3>(1, 1, DS_GET_BLOCKS(total_count)); //(batch_size * seq_length); + + range<3> block_dim = range<3>(1, 1, DS_CUDA_NUM_THREADS); //(hidden_dim / 4); + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + fused_add2_kernel(total_count, out, inp1, inp2, item_ct1); + }); + }); +} + +template void launch_fused_add2(float* out, + const float* inp1, + const float* inp2, + int batch_size, + int seq_length, + int hidden_dim, + queue* stream); +template void launch_fused_add2(bf16* out, + const bf16* inp1, + const bf16* inp2, + int batch_size, + int seq_length, + int hidden_dim, + queue* stream); +template void launch_fused_add2(half* out, + const half* inp1, + const half* inp2, + int batch_size, + int seq_length, + int hidden_dim, + queue* stream); + +void fused_add3_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int size, + int row_stride, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + + val.x() = inp1_reg.x() + inp2_reg.x() + inp3_reg.x(); + val.y() = inp1_reg.y() + inp2_reg.y() + inp3_reg.y(); + val.z() = inp1_reg.z() + inp2_reg.z() + inp3_reg.z(); + val.w() = inp1_reg.w() + inp2_reg.w() + inp3_reg.w(); + + out_4[row * row_stride + id] = val; +} + +void fused_add3_kernel(half* out, + const half* inp1, + const half* inp2, + const half* inp3, + int size, + int row_stride, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + + half2* inp1_h = reinterpret_cast(&inp1_4); + half2* inp2_h = reinterpret_cast(&inp2_4); + half2* inp3_h = reinterpret_cast(&inp3_4); + + float2 inp1_h_f_0 = inp1_h[0].convert(); + float2 inp1_h_f_1 = inp1_h[1].convert(); + + float2 inp2_h_f_0 = inp2_h[0].convert(); + float2 inp2_h_f_1 = inp2_h[1].convert(); + + float2 inp3_h_f_0 = inp3_h[0].convert(); + float2 inp3_h_f_1 = inp3_h[1].convert(); + + inp1_h_f_0.x() += (inp2_h_f_0.x() + inp3_h_f_0.x()); + inp1_h_f_0.y() += (inp2_h_f_0.y() + inp3_h_f_0.y()); + inp1_h_f_1.x() += (inp2_h_f_1.x() + inp3_h_f_1.x()); + inp1_h_f_1.y() += (inp2_h_f_1.y() + inp3_h_f_1.y()); + + float2 val_f; + half2* val_h = reinterpret_cast(&val_f); + + val_h[0] = inp1_h_f_0.convert(); + val_h[1] = inp1_h_f_1.convert(); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add3(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int batch_size, + int seq_length, + int hidden_size, + queue* stream) +{ + range<3> grid_dim(1, 1, batch_size * seq_length); + range<3> block_dim(1, 1, hidden_size / 4); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + fused_add3_kernel(out, + inp1, + inp2, + inp3, + (batch_size * seq_length * hidden_size), + hidden_size / 4, + item_ct1); + }); + }); +} + +template <> +void launch_fused_add3(half* out, + const half* inp1, + const half* inp2, + const half* inp3, + int batch_size, + int seq_length, + int hidden_size, + queue* stream) +{ + range<3> grid_dim(1, 1, batch_size * seq_length); + + range<3> block_dim(1, 1, hidden_size / 4); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + fused_add3_kernel(out, + inp1, + inp2, + inp3, + (batch_size * seq_length * hidden_size), + hidden_size / 4, + item_ct1); + }); + }); +} + +void fused_add4_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int size, + int row_stride, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + const float4* inp4_4 = reinterpret_cast(inp4); + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + float4 inp4_reg = inp4_4[row * row_stride + id]; + + val.x() = inp1_reg.x() + inp2_reg.x() + inp3_reg.x() + inp4_reg.x(); + val.y() = inp1_reg.y() + inp2_reg.y() + inp3_reg.y() + inp4_reg.y(); + val.z() = inp1_reg.z() + inp2_reg.z() + inp3_reg.z() + inp4_reg.z(); + val.w() = inp1_reg.w() + inp2_reg.w() + inp3_reg.w() + inp4_reg.w(); + + out_4[row * row_stride + id] = val; +} + +void fused_add4_kernel(half* out, + const half* inp1, + const half* inp2, + const half* inp3, + const half* inp4, + int size, + int row_stride, + nd_item<3> item_ct1) +{ + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + const float2* inp4_arr = reinterpret_cast(inp4); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + float2 inp4_4 = inp4_arr[row * row_stride + id]; + + half2* inp1_h = reinterpret_cast(&inp1_4); + half2* inp2_h = reinterpret_cast(&inp2_4); + half2* inp3_h = reinterpret_cast(&inp3_4); + half2* inp4_h = reinterpret_cast(&inp4_4); + + float2 inp1_h_f_0 = inp1_h[0].convert(); + float2 inp1_h_f_1 = inp1_h[1].convert(); + + float2 inp2_h_f_0 = inp2_h[0].convert(); + float2 inp2_h_f_1 = inp2_h[1].convert(); + + float2 inp3_h_f_0 = inp3_h[0].convert(); + float2 inp3_h_f_1 = inp3_h[1].convert(); + + float2 inp4_h_f_0 = inp4_h[0].convert(); + float2 inp4_h_f_1 = inp4_h[1].convert(); + + inp1_h_f_0.x() += (inp2_h_f_0.x() + inp3_h_f_0.x() + inp4_h_f_0.x()); + inp1_h_f_0.y() += (inp2_h_f_0.y() + inp3_h_f_0.y() + inp4_h_f_0.y()); + inp1_h_f_1.x() += (inp2_h_f_1.x() + inp3_h_f_1.x() + inp4_h_f_1.x()); + inp1_h_f_1.y() += (inp2_h_f_1.y() + inp3_h_f_1.y() + inp4_h_f_1.y()); + + float2 val_f; + half2* val_h = reinterpret_cast(&val_f); + + val_h[0] = inp1_h_f_0.convert(); + val_h[1] = inp1_h_f_1.convert(); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add4(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int batch_size, + int seq_length, + int hidden_size, + queue* stream) +{ + range<3> grid_dim(1, 1, batch_size * seq_length); + + range<3> block_dim(1, 1, hidden_size / 4); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + fused_add4_kernel(out, + inp1, + inp2, + inp3, + inp4, + (batch_size * seq_length * hidden_size), + hidden_size / 4, + item_ct1); + }); + }); +} + +template <> +void launch_fused_add4(half* out, + const half* inp1, + const half* inp2, + const half* inp3, + const half* inp4, + int batch_size, + int seq_length, + int hidden_size, + queue* stream) +{ + range<3> grid_dim(1, 1, batch_size * seq_length); + + range<3> block_dim(1, 1, hidden_size / 4); + + stream->submit([&](handler& cgh) { + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), [=](nd_item<3> item_ct1) { + fused_add4_kernel(out, + inp1, + inp2, + inp3, + inp4, + (batch_size * seq_length * hidden_size), + hidden_size / 4, + item_ct1); + }); + }); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/normalize_kernels.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/normalize_kernels.dp.cpp new file mode 100644 index 0000000..f66b67d --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/normalize_kernels.dp.cpp @@ -0,0 +1,2529 @@ +#include +#include "sycl/custom_sycl_layers.hpp" +using namespace cl::sycl; +/* + Fused bias add, residual (elementwise) add, and normalization layer. + + For FP16, this kernel does not promote to FP32 in order to utilize the 2x + throughput for + __half2 instructions, and avoid the conversion overhead (1/8 of __hal2 + arithmetic). + + For specific launch constraints, see the launch functions. +*/ + +#define NORM_REG (128) +#define MAX_SG_NUM (32) +#define MAX_SG_NUM1 (MAX_SG_NUM + 1) +#define TILE_DIM (32) +template +void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + float* means, + int row_stride, + nd_item<3> item_ct1, + float* shr) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // sycl::group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + // int gid = id / MAX_SG_NUM; + int gid = id / MAX_SG_NUM; + + float vals_arr[NORM_REG]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if (high_index < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + // for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + if (sg.get_local_id() == 0) shr[gid] = sum; + + item_ct1.barrier(); + + if (sg.get_local_id() < (iteration_stride >> 5)) sum = shr[sg.get_local_id()]; + +#if !defined(__STOCHASTIC_MODE__) + item_ct1.barrier(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += sg.shuffle_down(sum, i); } + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + float mean = sum / row_stride; + // if (training) + // if (g.thread_rank() == 0) means[row] = mean; + if constexpr (is_mean) { + if (training) + if (sg.get_local_id() == 0) means[row] = mean; + } + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + // for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { variance += sg.shuffle_down(variance, i); } + + // if (g.thread_rank() == 0) shr[gid] = variance; + if (sg.get_local_id() == 0) shr[gid] = variance; + + item_ct1.barrier(); + + if (sg.get_local_id() < (iteration_stride >> 5)) variance = shr[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { + variance += sg.shuffle_down(variance, i); + } + variance = sg.shuffle(variance, 0); + variance /= row_stride; + variance += epsilon; + + if (training) + if (sg.get_local_id() == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrt(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrt(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +template +void fused_bias_residual_layer_norm(bf16* vals, + const bf16* residual, + const bf16* gamma, + const bf16* beta, + float epsilon, + bool preLayerNorm, + bool training, + bf16* vars, + bf16* means, + int row_stride, + nd_item<3> item_ct1, + float* shr) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // sycl::group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + // int gid = id / MAX_SG_NUM; + int gid = id / MAX_SG_NUM; + + float vals_arr[NORM_REG]; + + ushort* vals_cast = reinterpret_cast(vals); + const ushort* residual_cast = reinterpret_cast(residual); + const ushort* gamma_cast = reinterpret_cast(gamma); + const ushort* beta_cast = reinterpret_cast(beta); + ushort* vars_cast = reinterpret_cast(vars); + ushort* means_cast = reinterpret_cast(means); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = bf16::to_float(residual_cast[i * iteration_stride + id]); + sum += vals_arr[i]; + } + if (high_index < row_stride) { + vals_arr[iterations] = bf16::to_float(residual_cast[high_index]); + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + if (sg.get_local_id() == 0) shr[gid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + if (sg.get_local_id() < (iteration_stride >> 5)) sum = shr[sg.get_local_id()]; + +#if !defined(__STOCHASTIC_MODE__) + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += + // g.shfl_down(sum, i); } + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += sg.shuffle_down(sum, i); } + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + float mean = sum / row_stride; + // if (training) + // if (g.thread_rank() == 0) means[row] = mean; + if constexpr (is_mean) { + if (training) + if (sg.get_local_id() == 0) means_cast[row] = bf16::from_float(mean); + } + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + // for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { variance += sg.shuffle_down(variance, i); } + + // if (g.thread_rank() == 0) shr[gid] = variance; + if (sg.get_local_id() == 0) shr[gid] = variance; + + item_ct1.barrier(); + + // if (g.thread_rank() < (iteration_stride >> 5)) variance = + // shr[g.thread_rank()]; + if (sg.get_local_id() < (iteration_stride >> 5)) variance = shr[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += + // g.shfl_down(variance, i); } + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { + variance += sg.shuffle_down(variance, i); + } + // variance = g.shfl(variance, 0); + variance = sg.shuffle(variance, 0); + variance /= row_stride; + variance += epsilon; + // if (training) + // if (g.thread_rank() == 0) vars[row] = variance; + if (training) + if (sg.get_local_id() == 0) vars_cast[row] = bf16::from_float(variance); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrt(variance); + vals_arr[i] = vals_arr[i] * bf16::to_float(gamma_cast[i * iteration_stride + id]) + + bf16::to_float(beta_cast[i * iteration_stride + id]); + vals_cast[i * iteration_stride + id] = bf16::from_float(vals_arr[i]); + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrt(variance); + vals_arr[iterations] = vals_arr[iterations] * bf16::to_float(gamma[high_index]) + + bf16::to_float(beta[high_index]); + vals_cast[high_index] = bf16::from_float(vals_arr[iterations]); + } +} + +template +void fused_bias_residual_layer_norm(half* vals, + const half* residual, + const half* gamma, + const half* beta, + float epsilon, + bool preLayerNorm, + bool training, + half* vars, + half* means, + int row_stride, + nd_item<3> item_ct1, + float* shr) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // cg::thread_block b = cg::this_thread_block(); + // cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + // int gid = id >> 5; + int gid = id / MAX_SG_NUM; + + float2 vals_f[NORM_REG]; + + half2* vals_cast = reinterpret_cast(vals); + const half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + // for (int i = 0; i < iterations; i++) { + // vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + // sum += vals_f[i].x; + // sum += vals_f[i].y; + // } + for (int i = 0; i < iterations; i++) { + vals_f[i] = residual_cast[i * iteration_stride + id] + .convert(); // __half22float2(residual_cast[i * + // iteration_stride + id]); + sum += vals_f[i].x(); + sum += vals_f[i].y(); + } + // if ((high_index) < row_stride) { + // vals_f[iterations] = __half22float2(residual_cast[high_index]); + // sum += vals_f[iterations].x; + // sum += vals_f[iterations].y; + // iterations++; + // } + if ((high_index) < row_stride) { + vals_f[iterations] = residual_cast[high_index] + .convert(); // __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x(); + sum += vals_f[iterations].y(); + iterations++; + } + + // for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) shr[gid] = sum; + if (sg.get_local_id() == 0) shr[gid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + if (sg.get_local_id() < (iteration_stride >> 5)) sum = shr[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + // b.sync(); + item_ct1.barrier(); +#endif + + // for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += + // g.shfl_down(sum, i); } + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += sg.shuffle_down(sum, i); } + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x() -= mean; + vals_f[i].y() -= mean; + variance += vals_f[i].x() * vals_f[i].x(); + variance += vals_f[i].y() * vals_f[i].y(); + } + + // for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { variance += sg.shuffle_down(variance, i); } + + // if (g.thread_rank() == 0) shr[gid] = variance; + if (sg.get_local_id() == 0) shr[gid] = variance; + + // b.sync(); + item_ct1.barrier(); + + // if (g.thread_rank() < (iteration_stride >> 5)) variance = + // shr[g.thread_rank()]; + if (sg.get_local_id() < (iteration_stride >> 5)) variance = shr[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + // b.sync(); + item_ct1.barrier(); +#endif + + // for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += + // g.shfl_down(variance, i); } + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { + variance += sg.shuffle_down(variance, i); + } + // variance = g.shfl(variance, 0); + variance = sg.shuffle(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + half2 variance_h = + vec({variance, variance}).convert(); // __float2half2_rn(variance); + const half2* gamma_cast = reinterpret_cast(gamma); + const half2* beta_cast = reinterpret_cast(beta); + + // if (training && g.thread_rank() == 0) { + // vars[row] = __float2half(variance); + // means[row] = __float2half(mean); + // } + if (training && sg.get_local_id() == 0) { + vars[row] = vec(variance).convert(); // __float2half(variance); + if constexpr (is_mean) { + means[row] = vec(mean).convert(); // __float2half(mean); + } + } + iterations = row_stride / iteration_stride; + // for (int i = 0; i < iterations; i++) { + // half2 vals_arr = __float22half2_rn(vals_f[i]); + // vals_arr = vals_arr * h2rsqrt(variance_h); + // vals_arr = + // vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * + // iteration_stride + id]; + // vals_cast[i * iteration_stride + id] = vals_arr; + // } + for (int i = 0; i < iterations; i++) { + half2 vals_arr = vals_f[i].convert(); // __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + // if ((high_index) < row_stride) { + // half2 vals_arr = __float22half2_rn(vals_f[iterations]); + // vals_arr = vals_arr * h2rsqrt(variance_h); + // vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + // vals_cast[high_index] = vals_arr; + // } + if ((high_index) < row_stride) { + half2 vals_arr = + vals_f[iterations].convert(); // __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + sycl::queue* stream, + bool preLayerNorm, + bool training, + T* vars, + T* means) +{ + int threads = THREADS; + + sycl::range<3> grid_dim(1, 1, batch_size); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + sycl::range<3> block_dim(1, 1, threads); + + stream->submit([&](sycl::handler& cgh) { + sycl::accessor + shr_acc_ct1(sycl::range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + cgh.parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + fused_bias_residual_layer_norm(vals, + residual, + gamma, + beta, + epsilon, + preLayerNorm, + training, + vars, + means, + hidden_dim, + item_ct1, + shr_acc_ct1.get_pointer()); + }); + }); +} + +template void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + sycl::queue* stream, + bool preLayerNorm, + bool training, + float* vars, + float* means); +template void launch_bias_residual_layer_norm(bf16* vals, + const bf16* residual, + const bf16* gamma, + const bf16* beta, + float epsilon, + int batch_size, + int hidden_dim, + sycl::queue* stream, + bool preLayerNorm, + bool training, + bf16* vars, + bf16* means); +template <> +void launch_bias_residual_layer_norm(half* vals, + const half* residual, + const half* gamma, + const half* beta, + float epsilon, + int batch_size, + int hidden_dim, + queue* stream, + bool preLayerNorm, + bool training, + half* vars, + half* means) +{ + int threads = 128; + + range<3> grid_dim(1, 1, batch_size); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim(1, 1, threads); + + stream->submit([&](handler& cgh) { + sycl::accessor + shr_acc_ct1(sycl::range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + fused_bias_residual_layer_norm(vals, + residual, + gamma, + beta, + epsilon, + preLayerNorm, + training, + vars, + means, + hidden_dim / 2, + item_ct1, + shr_acc_ct1.get_pointer()); + }); + }); +} + +/* + To tune this launch the following restrictions must be met: + + For float: + row_stride == hidden_size + threads * iterations == row_stride + threads is in [32, 64, 128, 256, 512, 1024] + + For half: + row_stride == hidden_size / 2 + threads * iterations == row_stride + threads is in [32, 64, 128, 256, 512, 1024] + +*/ + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + queue* stream, + bool preLayerNorm, + bool training, + T* vars) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, batch_size); + + // There are some limitations to call below functions, now just enumerate the + // situations. + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim(1, 1, threads); + + stream->submit([&](handler& cgh) { + accessor shr_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + fused_bias_residual_layer_norm(vals, + residual, + gamma, + beta, + epsilon, + preLayerNorm, + training, + vars, + nullptr, + hidden_dim, + item_ct1, + shr_acc_ct1.get_pointer()); + }); + }); +} + +template void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + queue* stream, + bool preLayerNorm, + bool training, + float* vars); +template void launch_bias_residual_layer_norm(bf16* vals, + const bf16* residual, + const bf16* gamma, + const bf16* beta, + float epsilon, + int batch_size, + int hidden_dim, + queue* stream, + bool preLayerNorm, + bool training, + bf16* vars); +template <> +void launch_bias_residual_layer_norm(half* vals, + const half* residual, + const half* gamma, + const half* beta, + float epsilon, + int batch_size, + int hidden_dim, + queue* stream, + bool preLayerNorm, + bool training, + half* vars) +{ + int threads = 128; + + range<3> grid_dim(1, 1, batch_size); + + // There are some limitations to call below functions, now just enumerate the + // situations. + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim(1, 1, threads); + + stream->submit([&](handler& cgh) { + sycl::accessor + shr_acc_ct1(sycl::range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + fused_bias_residual_layer_norm(vals, + residual, + gamma, + beta, + epsilon, + preLayerNorm, + training, + vars, + nullptr, + hidden_dim / 2, + item_ct1, + shr_acc_ct1.get_pointer()); + }); + }); +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using either X_hat or + * normalize input (invertible). + * Combine transpose with gradients computation. + */ + +template +void LayerNormBackward1(const T* out_grad, + const T* vals_hat, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + int rows, + int width, + bool invertible, + nd_item<3> item_ct1, + float* betta_buffer, + float* gamma_buffer) +{ + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int idx = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + int offset = item_ct1.get_local_id(1) * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = item_ct1.get_local_id(1); r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + // betta_buffer[item_ct1.get_local_id(2)][item_ct1.get_local_id(1)] = + // betta_tmp; gamma_buffer[item_ct1.get_local_id(2)][item_ct1.get_local_id(1)] + // = gamma_tmp; + betta_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = betta_tmp; + gamma_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = gamma_tmp; + + item_ct1.barrier(); + + // Sum the shared buffer. + // float s1 = + // betta_buffer[item_ct1.get_local_id(1)][item_ct1.get_local_id(2)]; float s2 + // = gamma_buffer[item_ct1.get_local_id(1)][item_ct1.get_local_id(2)]; + float s1 = betta_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + float s2 = gamma_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < TILE_DIM; i <<= 1) { + // s1 += g.shfl_down(s1, i); + // s2 += g.shfl_down(s2, i); + // } + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += sg.shuffle_down(s1, i); + s2 += sg.shuffle_down(s2, i); + } + + if (item_ct1.get_local_id(2) == 0) { + int pos = item_ct1.get_group(2) * TILE_DIM + item_ct1.get_local_id(1); + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using either X_hat or + * normalize input (invertible). + * Combine transpose with gradients computation. + */ + +template <> +void LayerNormBackward1(const bf16* out_grad, + const bf16* vals_hat, + const bf16* gamma, + const bf16* betta, + bf16* gamma_grad, + bf16* betta_grad, + int rows, + int width, + bool invertible, + nd_item<3> item_ct1, + float* betta_buffer, + float* gamma_buffer) +{ + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int idx = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + int offset = item_ct1.get_local_id(1) * width + idx; + int y_stride = width * TILE_DIM; + + const ushort* out_grad_cast = reinterpret_cast(out_grad); + const ushort* vals_hat_cast = reinterpret_cast(vals_hat); + const ushort* gamma_cast = reinterpret_cast(gamma); + const ushort* betta_cast = reinterpret_cast(betta); + ushort* gamma_grad_cast = reinterpret_cast(gamma_grad); + ushort* betta_grad_cast = reinterpret_cast(betta_grad); + + float betta_reg = (invertible ? bf16::to_float(betta_cast[idx]) : 0.0f); + float gamma_reg = bf16::to_float(gamma_cast[idx]); + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = item_ct1.get_local_id(1); r < rows; r += TILE_DIM) { + float grad = bf16::to_float(out_grad_cast[offset]); + float val = (invertible ? (bf16::to_float(vals_hat_cast[offset]) - betta_reg) / gamma_reg + : bf16::to_float(vals_hat_cast[offset])); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + // betta_buffer[item_ct1.get_local_id(2)][item_ct1.get_local_id(1)] = + // betta_tmp; gamma_buffer[item_ct1.get_local_id(2)][item_ct1.get_local_id(1)] + // = gamma_tmp; + betta_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = betta_tmp; + gamma_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = gamma_tmp; + + item_ct1.barrier(); + + // Sum the shared buffer. + // float s1 = + // betta_buffer[item_ct1.get_local_id(1)][item_ct1.get_local_id(2)]; float s2 + // = gamma_buffer[item_ct1.get_local_id(1)][item_ct1.get_local_id(2)]; + float s1 = betta_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + float s2 = gamma_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < TILE_DIM; i <<= 1) { + // s1 += g.shfl_down(s1, i); + // s2 += g.shfl_down(s2, i); + // } + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += sg.shuffle_down(s1, i); + s2 += sg.shuffle_down(s2, i); + } + + if (item_ct1.get_local_id(2) == 0) { + int pos = item_ct1.get_group(2) * TILE_DIM + item_ct1.get_local_id(1); + betta_grad_cast[pos] = bf16::from_float(s1); + gamma_grad_cast[pos] = bf16::from_float(s2); + } +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using the input to + * the normalize. + * Combine transpose with gradients computation. + */ + +template +void LayerNormBackward1(const T* out_grad, + const T* X_data, + const T* vars, + const T* means, + T* gamma_grad, + T* betta_grad, + int rows, + int width, + nd_item<3> item_ct1, + float* betta_buffer, + float* gamma_buffer) +{ + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int idx = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + int offset = item_ct1.get_local_id(1) * width + idx; + int y_stride = width * TILE_DIM; + + int pos = item_ct1.get_group(2) * TILE_DIM + item_ct1.get_local_id(1); + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = item_ct1.get_local_id(1); r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrt((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = betta_tmp; + gamma_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = gamma_tmp; + + item_ct1.barrier(); + + // Sum the shared buffer. + float s1 = betta_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + float s2 = gamma_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < TILE_DIM; i <<= 1) { + // s1 += g.shfl_down(s1, i); + // s2 += g.shfl_down(s2, i); + // } + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += sg.shuffle_down(s1, i); + s2 += sg.shuffle_down(s2, i); + } + + if (item_ct1.get_local_id(2) == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +template <> +void LayerNormBackward1(const bf16* out_grad, + const bf16* X_data, + const bf16* vars, + const bf16* means, + bf16* gamma_grad, + bf16* betta_grad, + int rows, + int width, + nd_item<3> item_ct1, + float* betta_buffer, + float* gamma_buffer) +{ + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int idx = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + int offset = item_ct1.get_local_id(1) * width + idx; + int y_stride = width * TILE_DIM; + + int pos = item_ct1.get_group(2) * TILE_DIM + item_ct1.get_local_id(1); + // Loop across matrix height + + const ushort* out_grad_cast = reinterpret_cast(out_grad); + const ushort* X_data_cast = reinterpret_cast(X_data); + const ushort* vars_cast = reinterpret_cast(vars); + const ushort* means_cast = reinterpret_cast(means); + ushort* gamma_grad_cast = reinterpret_cast(gamma_grad); + ushort* betta_grad_cast = reinterpret_cast(betta_grad); + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = item_ct1.get_local_id(1); r < rows; r += TILE_DIM) { + float grad = bf16::to_float(out_grad_cast[offset]); + float val = bf16::to_float(X_data_cast[offset]); + val = (val - bf16::to_float(means_cast[r])) * rsqrt(bf16::to_float(vars_cast[r])); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = betta_tmp; + gamma_buffer[item_ct1.get_local_id(2) * MAX_SG_NUM1 + item_ct1.get_local_id(1)] = gamma_tmp; + + item_ct1.barrier(); + + // Sum the shared buffer. + float s1 = betta_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + float s2 = gamma_buffer[item_ct1.get_local_id(1) * MAX_SG_NUM1 + item_ct1.get_local_id(2)]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < TILE_DIM; i <<= 1) { + // s1 += g.shfl_down(s1, i); + // s2 += g.shfl_down(s2, i); + // } + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += sg.shuffle_down(s1, i); + s2 += sg.shuffle_down(s2, i); + } + + if (item_ct1.get_local_id(2) == 0) { + betta_grad_cast[pos] = bf16::from_float(s1); + gamma_grad_cast[pos] = bf16::from_float(s2); + } +} +/* + +/* Backward Normalize (Input-Gradient) +* Using the means and variances from the input +* This type of backward is invertible! +* We do the backward using the X_hat (X - u) / sqrt(variance) or the output of +Normalization. +*/ +template +void LayerNormBackward2(const float* out_grad, + const float* out_grad_add, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride, + nd_item<3> item_ct1, + float* partialSum) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int wid = id / MAX_SG_NUM; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / MAX_SG_NUM; + + out_grad += (row * row_stride); + if constexpr (is_fuseadd) { out_grad_add += (row * row_stride); } + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += + vals_hat_arr[i] * vals_arr[i] * sqrt(var_reg); // dval_hat = gamma * (x - u) * out_grad + vals_arr[i] *= rsqrt(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) + } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + if constexpr (is_fuseadd) { + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad_add[i * iteration_stride + id]; + } else { + inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + } + if ((high_index) < row_stride) + if constexpr (is_fuseadd) { + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad_add[high_index]; + } else { + inp_grad[high_index] = (vals_arr[iterations] - sum); + } +} + +template +void LayerNormBackward2(const bf16* out_grad, + const bf16* out_grad_add, + const bf16* vals_hat, + const bf16* gamma, + const bf16* betta, + const bf16* vars, + bf16* inp_grad, + bool invertible, + int row_stride, + nd_item<3> item_ct1, + float* partialSum) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int wid = id / MAX_SG_NUM; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / MAX_SG_NUM; + + const ushort* out_grad_cast = reinterpret_cast(out_grad); + const ushort* out_grad_add_cast = reinterpret_cast(out_grad_add); + const ushort* vals_hat_cast = reinterpret_cast(vals_hat); + const ushort* gamma_cast = reinterpret_cast(gamma); + const ushort* betta_cast = reinterpret_cast(betta); + const ushort* vars_cast = reinterpret_cast(vars); + ushort* inp_grad_cast = reinterpret_cast(inp_grad); + + out_grad_cast += (row * row_stride); + if constexpr (is_fuseadd) { out_grad_add_cast += (row * row_stride); } + vals_hat_cast += (row * row_stride); + inp_grad_cast += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = bf16::to_float(gamma_cast[i * iteration_stride + id]); + vals_arr[i] = bf16::to_float(out_grad_cast[i * iteration_stride + id]); + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = (invertible ? (bf16::to_float(vals_hat_cast[i * iteration_stride + id]) - + bf16::to_float(betta_cast[i * iteration_stride + id])) / + gamma_reg + : bf16::to_float(vals_hat_cast[i * iteration_stride + id])); + } + if ((high_index) < row_stride) { + float gamma_reg = bf16::to_float(gamma_cast[high_index]); + vals_arr[iterations] = bf16::to_float(out_grad_cast[high_index]); + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = (invertible ? (bf16::to_float(vals_hat_cast[high_index]) - + bf16::to_float(betta_cast[high_index])) / + gamma_reg + : bf16::to_float(vals_hat_cast[high_index])); + iterations++; + } + + float var_reg = bf16::to_float(vars_cast[row]); + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += + vals_hat_arr[i] * vals_arr[i] * sqrt(var_reg); // dval_hat = gamma * (x - u) * out_grad + vals_arr[i] *= rsqrt(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) + } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + if constexpr (is_fuseadd) { + inp_grad_cast[i * iteration_stride + id] = bf16::from_float( + (vals_arr[i] - sum) + bf16::to_float(out_grad_add_cast[i * iteration_stride + id])); + } else { + inp_grad_cast[i * iteration_stride + id] = bf16::from_float((vals_arr[i] - sum)); + } + if ((high_index) < row_stride) + if constexpr (is_fuseadd) { + inp_grad_cast[high_index] = bf16::from_float( + (vals_arr[iterations] - sum) + bf16::to_float(out_grad_add_cast[high_index])); + } else { + inp_grad_cast[high_index] = bf16::from_float((vals_arr[iterations] - sum)); + } +} + +template +void LayerNormBackward2(const half* out_grad, + const half* out_grad_add, + const half* vals_hat, + const half* gamma, + const half* betta, + const half* vars, + half* inp_grad, + bool invertible, + int row_stride, + nd_item<3> item_ct1, + float* partialSum) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int wid = id / MAX_SG_NUM; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / MAX_SG_NUM; + + half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + half2 vals_hat_arr[NORM_REG]; + + half2* inp_grad_h = reinterpret_cast(inp_grad); + const half2* out_grad_h = reinterpret_cast(out_grad); + const half2* out_grad_add_h = reinterpret_cast(out_grad_add); + const half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + if constexpr (is_fuseadd) { out_grad_add_h += (row * row_stride); } + vals_hat_h += (row * row_stride); + + const half2* gamma_h = reinterpret_cast(gamma); + const half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + half var_h = vars[row]; + half2 var_reg = half2{var_h, var_h}; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + half2 result_h = (vals_hat_arr[i] * vals_arr[i] * sqrt(var_reg)); + float2 result_f = result_h.convert(); + sum += result_f.x(); + sum += result_f.y(); + vals_arr[i] *= rsqrt(var_reg); + } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= (2 * row_stride); + half2 sum_h = float2{sum, sum}.convert(); + + for (int i = 0; i < iterations; i++) { + half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = vals_arr[i].convert(); + float2 temp_f = temp.convert(); + vals_arr_f[i].x() += temp_f.x(); + vals_arr_f[i].y() += temp_f.y(); + } + sum = 0.f; + + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x()); + sum += (vals_arr_f[i].y()); + } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (sg.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + // sum = sg.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x() -= sum; + vals_arr_f[i].y() -= sum; + half2 temp = vals_arr_f[i].convert(); + if constexpr (is_fuseadd) { + inp_grad_h[i * iteration_stride + id] = + temp + out_grad_add_h[i * iteration_stride + id]; + } else { + inp_grad_h[i * iteration_stride + id] = temp; + } + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x() -= sum; + vals_arr_f[iterations].y() -= sum; + half2 temp = vals_arr_f[iterations].convert(); + if constexpr (is_fuseadd) { + inp_grad_h[high_index] = temp + out_grad_add_h[high_index]; + } else { + inp_grad_h[high_index] = temp; + } + } +} + +template +void launch_layerNorm_backward(const T* out_grad, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const T* betta) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad, + vals_hat, + gamma, + betta, + gamma_grad, + betta_grad, + batch, + hidden_dim, + invertible, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + // LayerNormBackward1<<>>( + // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, + // hidden_dim, invertible); + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads); + + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad, + nullptr, + vals_hat, + gamma, + betta, + vars, + inp_grad, + invertible, + hidden_dim, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +template void launch_layerNorm_backward(const float* out_grad, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const float* betta); + +template void launch_layerNorm_backward(const bf16* out_grad, + const bf16* vals_hat, + const bf16* vars, + const bf16* gamma, + bf16* gamma_grad, + bf16* betta_grad, + bf16* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const bf16* betta); + +template <> +void launch_layerNorm_backward(const half* out_grad, + const half* vals_hat, + const half* vars, + const half* gamma, + half* gamma_grad, + half* betta_grad, + half* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const half* betta) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + // LayerNormBackward1<<>>( + // out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, + // hidden_dim, invertible); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad, + vals_hat, + gamma, + betta, + gamma_grad, + betta_grad, + batch, + hidden_dim, + invertible, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads / 2); + + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad, + nullptr, + vals_hat, + gamma, + betta, + vars, + inp_grad, + invertible, + hidden_dim / 2, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ +template +void LayerNormBackward2(const float* out_grad, + const float* out_grad_add, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride, + nd_item<3> item_ct1, + float* partialSum) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int wid = id / MAX_SG_NUM; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / MAX_SG_NUM; + + out_grad += (row * row_stride); + if constexpr (is_fuseadd) { out_grad_add += (row * row_stride); } + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (X_vals[i * iteration_stride + id] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrt(var_reg); + } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrt(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + if constexpr (is_fuseadd) { + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad_add[i * iteration_stride + id]; + } else { + inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + } + if ((high_index) < row_stride) + if constexpr (is_fuseadd) { + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad_add[high_index]; + } else { + inp_grad[high_index] = (vals_arr[iterations] - sum); + } +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ +template +void LayerNormBackward2(const bf16* out_grad, + const bf16* out_grad_add, + const bf16* X_vals, + const bf16* gamma, + const bf16* vars, + const bf16* means, + bf16* inp_grad, + int row_stride, + nd_item<3> item_ct1, + float* partialSum) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + const ushort* out_grad_cast = reinterpret_cast(out_grad); + const ushort* out_grad_add_cast = reinterpret_cast(out_grad_add); + const ushort* X_vals_cast = reinterpret_cast(X_vals); + const ushort* gamma_cast = reinterpret_cast(gamma); + const ushort* vars_cast = reinterpret_cast(vars); + const ushort* means_cast = reinterpret_cast(means); + ushort* inp_grad_cast = reinterpret_cast(inp_grad); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int wid = id / MAX_SG_NUM; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / MAX_SG_NUM; + + out_grad_cast += (row * row_stride); + if constexpr (is_fuseadd) { out_grad_add_cast += (row * row_stride); } + X_vals_cast += (row * row_stride); + inp_grad_cast += (row * row_stride); + + float vals_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = bf16::to_float(gamma_cast[i * iteration_stride + id]); + vals_arr[i] = bf16::to_float(out_grad_cast[i * iteration_stride + id]); + vals_arr[i] *= gamma_reg; + } + if ((high_index) < row_stride) { + float gamma_reg = bf16::to_float(gamma_cast[high_index]); + vals_arr[iterations] = bf16::to_float(out_grad_cast[high_index]); + vals_arr[iterations] *= gamma_reg; + iterations++; + } + + float var_reg = bf16::to_float(vars_cast[row]); + float mean_reg = bf16::to_float(means_cast[row]); + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (bf16::to_float(X_vals_cast[i * iteration_stride + id]) - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrt(var_reg); + } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + // for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + // sum = g.shfl(sum, 0); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrt(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + // for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += g.shfl_down(sum, i); } + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + // if (g.thread_rank() == 0) partialSum[wid] = sum; + if (sg.get_local_id() == 0) partialSum[wid] = sum; + item_ct1.barrier(); + + // if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + sum = sg.shuffle(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + if constexpr (is_fuseadd) { + inp_grad_cast[i * iteration_stride + id] = bf16::from_float( + (vals_arr[i] - sum) + bf16::to_float(out_grad_add_cast[i * iteration_stride + id])); + } else { + inp_grad_cast[i * iteration_stride + id] = bf16::from_float(vals_arr[i] - sum); + } + if ((high_index) < row_stride) + if constexpr (is_fuseadd) { + inp_grad_cast[high_index] = bf16::from_float( + (vals_arr[iterations] - sum) + bf16::to_float(out_grad_add_cast[high_index])); + } else { + inp_grad_cast[high_index] = bf16::from_float(vals_arr[iterations] - sum); + } +} + +template +void LayerNormBackward2(const half* out_grad, + const half* out_grad_add, + const half* X_vals, + const half* gamma, + const half* vars, + const half* means, + half* inp_grad, + int row_stride, + nd_item<3> item_ct1, + float* partialSum) +{ + int iteration_stride = item_ct1.get_local_range(2); + int iterations = row_stride / iteration_stride; + + // group<3> b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + int wid = id / MAX_SG_NUM; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / MAX_SG_NUM; + + half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + + half2* inp_grad_h = reinterpret_cast(inp_grad); + const half2* out_grad_h = reinterpret_cast(out_grad); + const half2* out_grad_add_h = reinterpret_cast(out_grad_add); + const half2* vals_hat_h = reinterpret_cast(X_vals); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + if constexpr (is_fuseadd) { out_grad_add_h += (row * row_stride); } + vals_hat_h += (row * row_stride); + + const half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + } + if ((high_index) < row_stride) { + half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + iterations++; + } + half mean_h = means[row]; + half var_h = vars[row]; + half2 var_reg = half2{var_h, var_h}; + half2 mean_reg = half2{mean_h, mean_h}; + half2 xu[NORM_REG]; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg); + half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = result_h.convert(); + sum += result_f.x(); + sum += result_f.y(); + vals_arr[i] *= rsqrt(var_reg); + } + + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + sum = sg.shuffle(sum, 0); + sum /= (2 * row_stride); + half2 sum_h = float2{sum, sum}.convert(); + + for (int i = 0; i < iterations; i++) { + half2 xu_grad = ((-sum_h * xu[i] * rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = vals_arr[i].convert(); + float2 xu_grad_f = xu_grad.convert(); + vals_arr_f[i].x() += xu_grad_f.x(); + vals_arr_f[i].y() += xu_grad_f.y(); + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x()); + sum += (vals_arr_f[i].y()); + } + + for (int i = 1; i < MAX_SG_NUM; i *= 2) { sum += sg.shuffle_down(sum, i); } + + if (sg.get_local_id() == 0) partialSum[wid] = sum; + + item_ct1.barrier(); + + if (sg.get_local_id() < warp_num) sum = partialSum[sg.get_local_id()]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += sg.shuffle_down(sum, i); + + sum = sg.shuffle(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x() -= sum; + vals_arr_f[i].y() -= sum; + half2 temp = vals_arr_f[i].convert(); + if constexpr (is_fuseadd) { + inp_grad_h[i * iteration_stride + id] = + temp + out_grad_add_h[i * iteration_stride + id]; + } else { + inp_grad_h[i * iteration_stride + id] = temp; + } + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x() -= sum; + vals_arr_f[iterations].y() -= sum; + half2 temp = vals_arr_f[iterations].convert(); + if constexpr (is_fuseadd) { + inp_grad_h[high_index] = temp + out_grad_add_h[high_index]; + } else { + inp_grad_h[high_index] = temp; + } + } +} + +template +void launch_layerNorm_backward(const T* out_grad, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + // LayerNormBackward1<<>>( + // out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, + // hidden_dim); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad, + X_data, + vars, + means, + gamma_grad, + betta_grad, + batch, + hidden_dim, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads); + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad, + nullptr, + X_data, + gamma, + vars, + means, + inp_grad, + hidden_dim, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +template void launch_layerNorm_backward(const float* out_grad, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]); +template void launch_layerNorm_backward(const bf16* out_grad, + const bf16* X_data, + const bf16* vars, + const bf16* means, + const bf16* gamma, + bf16* gamma_grad, + bf16* betta_grad, + bf16* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]); +template <> +void launch_layerNorm_backward(const half* out_grad, + const half* X_data, + const half* vars, + const half* means, + const half* gamma, + half* gamma_grad, + half* betta_grad, + half* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + // LayerNormBackward1<<>>( + // out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, + // hidden_dim); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad, + X_data, + vars, + means, + gamma_grad, + betta_grad, + batch, + hidden_dim, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads / 2); + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad, + nullptr, + X_data, + gamma, + vars, + means, + inp_grad, + hidden_dim / 2, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* vals_hat, + const T* vars, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const T* betta) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + // LayerNormBackward1<<>>( + // out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, + // hidden_dim, invertible); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad1, + vals_hat, + gamma, + betta, + gamma_grad, + betta_grad, + batch, + hidden_dim, + invertible, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads); + + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad1, + out_grad2, + vals_hat, + gamma, + betta, + vars, + inp_grad, + invertible, + hidden_dim, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +template void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const float* betta); + +template void launch_layerNorm_backward_fused_add(const bf16* out_grad1, + const bf16* out_grad2, + const bf16* vals_hat, + const bf16* vars, + const bf16* gamma, + bf16* gamma_grad, + bf16* betta_grad, + bf16* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const bf16* betta); + +template <> +void launch_layerNorm_backward_fused_add(const half* out_grad1, + const half* out_grad2, + const half* vals_hat, + const half* vars, + const half* gamma, + half* gamma_grad, + half* betta_grad, + half* inp_grad, + int batch, + int hidden_dim, + queue* stream[2], + bool invertible, + const half* betta) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + // LayerNormBackward1<<>>( + // out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, + // hidden_dim, invertible); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad1, + vals_hat, + gamma, + betta, + gamma_grad, + betta_grad, + batch, + hidden_dim, + invertible, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads / 2); + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad1, + out_grad2, + vals_hat, + gamma, + betta, + vars, + inp_grad, + invertible, + hidden_dim / 2, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +template +void launch_layerNorm_backward_fused_add(const T* out_grad1, + const T* out_grad2, + const T* X_data, + const T* vars, + const T* means, + const T* gamma, + T* gamma_grad, + T* betta_grad, + T* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + // LayerNormBackward1<<>>( + // out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, + // hidden_dim); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad1, + X_data, + vars, + means, + gamma_grad, + betta_grad, + batch, + hidden_dim, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads); + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad1, + out_grad2, + X_data, + gamma, + vars, + means, + inp_grad, + hidden_dim, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} + +template void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]); +template void launch_layerNorm_backward_fused_add(const bf16* out_grad1, + const bf16* out_grad2, + const bf16* X_data, + const bf16* vars, + const bf16* means, + const bf16* gamma, + bf16* gamma_grad, + bf16* betta_grad, + bf16* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]); +template <> +void launch_layerNorm_backward_fused_add(const half* out_grad1, + const half* out_grad2, + const half* X_data, + const half* vars, + const half* means, + const half* gamma, + half* gamma_grad, + half* betta_grad, + half* inp_grad, + int batch, + int hidden_dim, + queue* stream[2]) +{ + int threads = THREADS; + + range<3> grid_dim(1, 1, hidden_dim / TILE_DIM); + range<3> block_dim(1, TILE_DIM, TILE_DIM); + + // LayerNormBackward1<<>>( + // out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, + // hidden_dim); + stream[0]->submit([&](handler& cgh) { + accessor betta_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + accessor gamma_buffer( + range<2>(MAX_SG_NUM /*MAX_WARP_NUM*/, MAX_SG_NUM1), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward1(out_grad1, + X_data, + vars, + means, + gamma_grad, + betta_grad, + batch, + hidden_dim, + item_ct1, + betta_buffer.get_pointer(), + gamma_buffer.get_pointer()); + }); + }); + + range<3> grid_dim2(1, 1, batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + range<3> block_dim2(1, 1, threads / 2); + stream[1]->submit([&](handler& cgh) { + accessor partialSum_acc_ct1( + range<1>(MAX_SG_NUM /*MAX_WARP_NUM*/), cgh); + + cgh.parallel_for(nd_range<3>(grid_dim2 * block_dim2, block_dim2), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + LayerNormBackward2(out_grad1, + out_grad2, + X_data, + gamma, + vars, + means, + inp_grad, + hidden_dim / 2, + item_ct1, + partialSum_acc_ct1.get_pointer()); + }); + }); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onednn_wrappers.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onednn_wrappers.dp.cpp new file mode 100644 index 0000000..dbeef56 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onednn_wrappers.dp.cpp @@ -0,0 +1,124 @@ +#include "sycl/onednn_wrappers.hpp" +#include + +template +inline int onednn_matmul(sycl::queue* handle, + bool trans_src, + bool trans_wgt, + int m, + int n, + int k, + const float alpha, + const float beta, + const bf16* src_ptr, + const bf16* wgt_ptr, + bf16* dst_ptr, + int batch) +{ + /* + * src, [m, k], m: batch, k: in_feature + * wgt, [k, n], n: k: in_features, out_feature + * dst, [m, n], m: batch, n: out_features + */ + device dev = handle->get_device(); + context ctx = handle->get_context(); + dnnl::engine engine = dnnl::sycl_interop::make_engine(dev, ctx); + dnnl::stream stream = dnnl::sycl_interop::make_stream(engine, *handle); + + dnnl::memory::dims src_dims, wgt_dims, dst_dims; + + if constexpr (bmm) { + src_dims = {batch, m, k}; + wgt_dims = {batch, k, n}; + dst_dims = {batch, m, n}; + } else { + src_dims = {m, k}; + wgt_dims = {k, n}; + dst_dims = {m, n}; + } + + dnnl::memory::desc src_md, wgt_md, dst_md; + + if constexpr (bmm) { + src_md = dnnl::memory::desc( + src_dims, + dnnl::memory::data_type::bf16, + trans_src ? dnnl::memory::format_tag::acb : dnnl::memory::format_tag::abc); + wgt_md = dnnl::memory::desc( + wgt_dims, + dnnl::memory::data_type::bf16, + trans_wgt ? dnnl::memory::format_tag::acb : dnnl::memory::format_tag::abc); + dst_md = dnnl::memory::desc( + dst_dims, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::abc); + } else { + src_md = dnnl::memory::desc( + src_dims, + dnnl::memory::data_type::bf16, + trans_src ? dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab); + wgt_md = dnnl::memory::desc( + wgt_dims, + dnnl::memory::data_type::bf16, + trans_wgt ? dnnl::memory::format_tag::ba : dnnl::memory::format_tag::ab); + dst_md = dnnl::memory::desc( + dst_dims, dnnl::memory::data_type::bf16, dnnl::memory::format_tag::ab); + } + + auto src_mem = dnnl::memory(src_md, engine, (void*)src_ptr); + auto wgt_mem = dnnl::memory(wgt_md, engine, (void*)wgt_ptr); + auto dst_mem = dnnl::memory(dst_md, engine, (void*)dst_ptr); + + auto matmul_d = dnnl::matmul::desc(src_md, wgt_md, dst_md); + + dnnl::primitive_attr attr; + if (alpha != 1.0f) attr.set_output_scales(0, {alpha}); + if (beta != 0.0f) { + dnnl::post_ops po; + po.append_sum(beta); + attr.set_post_ops(po); + } + + auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + + auto matmul_prim = dnnl::matmul(matmul_pd); + + std::unordered_map matmul_args; + matmul_args.insert({DNNL_ARG_SRC, src_mem}); + matmul_args.insert({DNNL_ARG_WEIGHTS, wgt_mem}); + matmul_args.insert({DNNL_ARG_DST, dst_mem}); + + matmul_prim.execute(stream, matmul_args); + stream.wait(); +} + +int onednn_matmul_ex(sycl::queue* handle, + bool trans_src, + bool trans_wgt, + int m, + int n, + int k, + const float alpha, + const float beta, + const bf16* src_ptr, + const bf16* wgt_ptr, + bf16* dst_ptr) +{ + onednn_matmul( + handle, trans_src, trans_wgt, m, n, k, alpha, beta, src_ptr, wgt_ptr, dst_ptr, 1); +} + +int onednn_batchgemm(sycl::queue* handle, + int m, + int n, + int k, + const float alpha, + const float beta, + const bf16* src_ptr, + const bf16* wgt_ptr, + bf16* dst_ptr, + bool trans_src, + bool trans_wgt, + int batch) +{ + onednn_matmul( + handle, trans_src, trans_wgt, m, n, k, alpha, beta, src_ptr, wgt_ptr, dst_ptr, batch); +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onemkl_wrappers.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onemkl_wrappers.dp.cpp new file mode 100644 index 0000000..b20eb78 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/onemkl_wrappers.dp.cpp @@ -0,0 +1,144 @@ +#include "sycl/onemkl_wrappers.hpp" +#include + +int onemkl_gemm_ex(sycl::queue* handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const float alpha, + const float beta, + const float* A, + const float* B, + float* C) +{ + try { + int lda = (transa == oneapi::mkl::transpose::nontrans) ? m : k; + int ldb = (transb == oneapi::mkl::transpose::nontrans) ? k : n; + int ldc = m; + oneapi::mkl::blas::gemm( + *handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ + << std::endl; + std::exit(1); + } +} + +int onemkl_gemm_ex(sycl::queue* handle, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int m, + int n, + int k, + const sycl::half alpha, + const sycl::half beta, + const sycl::half* A, + const sycl::half* B, + sycl::half* C) +{ + try { + int lda = (transa == oneapi::mkl::transpose::nontrans) ? m : k; + int ldb = (transb == oneapi::mkl::transpose::nontrans) ? k : n; + int ldc = m; + oneapi::mkl::blas::gemm( + *handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ + << std::endl; + std::exit(1); + } +} + +int onemkl_strided_batched_gemm(sycl::queue* handle, + int m, + int n, + int k, + const float alpha, + const float beta, + const float* A, + const float* B, + float* C, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo) +{ + try { + int lda = (transa == oneapi::mkl::transpose::nontrans) ? m : k; + int ldb = (transb == oneapi::mkl::transpose::nontrans) ? k : n; + int ldc = m; + oneapi::mkl::blas::gemm_batch(*handle, + transa, + transb, + m, + n, + k, + alpha, + A, + lda, + stride_A, + B, + ldb, + stride_B, + beta, + C, + ldc, + stride_C, + batch); + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ + << " (batch, m, n, k)" << batch << " " << m << " " << n << " " << k << std::endl; + std::exit(1); + } +} + +int onemkl_strided_batched_gemm(sycl::queue* handle, + int m, + int n, + int k, + const sycl::half alpha, + const sycl::half beta, + const sycl::half* A, + const sycl::half* B, + sycl::half* C, + oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, + int stride_A, + int stride_B, + int stride_C, + int batch, + int algo) +{ + try { + int lda = (transa == oneapi::mkl::transpose::nontrans) ? m : k; + int ldb = (transb == oneapi::mkl::transpose::nontrans) ? k : n; + int ldc = m; + oneapi::mkl::blas::gemm_batch(*handle, + transa, + transb, + m, + n, + k, + alpha, + A, + lda, + stride_A, + B, + ldb, + stride_B, + beta, + C, + ldc, + stride_C, + batch); + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ + << " (batch, m, n, k)" << batch << " " << m << " " << n << " " << k << std::endl; + std::exit(1); + } +} diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/softmax_kernels.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/softmax_kernels.dp.cpp new file mode 100644 index 0000000..9dbf7e7 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/softmax_kernels.dp.cpp @@ -0,0 +1,854 @@ +#include +#include "sycl/custom_sycl_layers.hpp" +#include "sycl/general_kernels.hpp" + +using namespace cl::sycl; +#define MAX_SG_NUM (32) +// Fused attention + softmax +template +void attn_softmax(float* vals, + const float* attn_mask, + int heads, + int seq_length, + int iterations, + nd_item<3> item_ct1, + float* partialSum) +{ + int sg_num = item_ct1.get_local_range().get(2) >> 5; + + int iteration_stride = item_ct1.get_local_range().get(2); + int block_width = blockStride * seq_length; + + // auto b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int batch = item_ct1.get_group(2); + int row = item_ct1.get_group(1); + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = item_ct1.get_local_id(2) % max_threads_in_sequence; + + int data_offset = batch * (item_ct1.get_group_range(1) * block_width) + row * block_width + + (item_ct1.get_local_id(2) / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = item_ct1.get_local_id(2) >> 5; + int lane = item_ct1.get_local_id(2) & 0x1f; + + float4* val_cast = reinterpret_cast(vals); + const float4* attn_mask_cast = reinterpret_cast(attn_mask); + + float4 data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float4 mask = attn_mask_cast[mask_offset + data_id]; + data[i] = val_cast[data_offset + data_id]; + data[i].x() += mask.x(); + data[i].y() += mask.y(); + data[i].z() += mask.z(); + data[i].w() += mask.w(); + + max_val = (data[i].x() > max_val ? data[i].x() : max_val); + max_val = (data[i].y() > max_val ? data[i].y() : max_val); + max_val = (data[i].z() > max_val ? data[i].z() : max_val); + max_val = (data[i].w() > max_val ? data[i].w() : max_val); + } else { + data[i].x() = minus_infinity; + data[i].y() = minus_infinity; + data[i].z() = minus_infinity; + data[i].w() = minus_infinity; + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = sg.shuffle_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + item_ct1.barrier(); + + if (lane < sg_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + int iters = sg_num; + if (seq_length < iteration_stride) + iters = sg_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = sg.shuffle_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = sg.shuffle(max_val, item_ct1.get_local_id(2) / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x() = sycl::exp(data[i].x() - max_val); + data[i].y() = sycl::exp(data[i].y() - max_val); + data[i].z() = sycl::exp(data[i].z() - max_val); + data[i].w() = sycl::exp(data[i].w() - max_val); + + sum += (data[i].x() + data[i].y() + data[i].z() + data[i].w()); + } + + for (int i = 1; i < tbSize; i *= 2) { sum += sg.shuffle_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + item_ct1.barrier(); + + if (lane < sg_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + int iters = sg_num; + if (seq_length < iteration_stride) + iters = sg_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += sg.shuffle_xor(sum, i); } + + sum = sg.shuffle(sum, item_ct1.get_local_id(2) / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + data[i].x() /= sum; + data[i].y() /= sum; + data[i].z() /= sum; + data[i].w() /= sum; + + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) val_cast[data_offset + data_id] = data[i]; + } +} + +template +void attn_softmax(bf16* vals, + const bf16* attn_mask, + int heads, + int seq_length, + int iterations, + nd_item<3> item_ct1, + float* partialSum) +{ + int sg_num = item_ct1.get_local_range().get(2) >> 5; + + int iteration_stride = item_ct1.get_local_range().get(2); + int block_width = blockStride * seq_length; + + // auto b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int batch = item_ct1.get_group(2); + int row = item_ct1.get_group(1); + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = item_ct1.get_local_id(2) % max_threads_in_sequence; + + int data_offset = batch * (item_ct1.get_group_range(1) * block_width) + row * block_width + + (item_ct1.get_local_id(2) / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = item_ct1.get_local_id(2) >> 5; + int lane = item_ct1.get_local_id(2) & 0x1f; + + ushort4* val_cast = reinterpret_cast(vals); + const ushort4* attn_mask_cast = reinterpret_cast(attn_mask); + + float4 data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + ushort4 mask_ushort = attn_mask_cast[mask_offset + data_id]; + ushort4 val_ushort = val_cast[data_offset + data_id]; + float4 mask = {bf16::to_float(mask_ushort.x()), + bf16::to_float(mask_ushort.y()), + bf16::to_float(mask_ushort.z()), + bf16::to_float(mask_ushort.w())}; + data[i] = {bf16::to_float(val_ushort.x()), + bf16::to_float(val_ushort.y()), + bf16::to_float(val_ushort.z()), + bf16::to_float(val_ushort.w())}; + + data[i].x() += mask.x(); + data[i].y() += mask.y(); + data[i].z() += mask.z(); + data[i].w() += mask.w(); + + max_val = (data[i].x() > max_val ? data[i].x() : max_val); + max_val = (data[i].y() > max_val ? data[i].y() : max_val); + max_val = (data[i].z() > max_val ? data[i].z() : max_val); + max_val = (data[i].w() > max_val ? data[i].w() : max_val); + } else { + data[i].x() = minus_infinity; + data[i].y() = minus_infinity; + data[i].z() = minus_infinity; + data[i].w() = minus_infinity; + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = sg.shuffle_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + item_ct1.barrier(); + + if (lane < sg_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + int iters = sg_num; + if (seq_length < iteration_stride) + iters = sg_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = sg.shuffle_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = sg.shuffle(max_val, item_ct1.get_local_id(2) / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x() = sycl::exp(data[i].x() - max_val); + data[i].y() = sycl::exp(data[i].y() - max_val); + data[i].z() = sycl::exp(data[i].z() - max_val); + data[i].w() = sycl::exp(data[i].w() - max_val); + + sum += (data[i].x() + data[i].y() + data[i].z() + data[i].w()); + } + + for (int i = 1; i < tbSize; i *= 2) { sum += sg.shuffle_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + item_ct1.barrier(); + + if (lane < sg_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + int iters = sg_num; + if (seq_length < iteration_stride) + iters = sg_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += sg.shuffle_xor(sum, i); } + + sum = sg.shuffle(sum, item_ct1.get_local_id(2) / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + data[i].x() /= sum; + data[i].y() /= sum; + data[i].z() /= sum; + data[i].w() /= sum; + + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + ushort4 data_ushort = {bf16::from_float(data[i].x()), + bf16::from_float(data[i].y()), + bf16::from_float(data[i].z()), + bf16::from_float(data[i].w())}; + val_cast[data_offset + data_id] = data_ushort; + } + } +} + +template +void attn_softmax(half* vals, + const half* attn_mask, + int heads, + int seq_length, + int iterations, + nd_item<3> item_ct1, + float* partialSum) +{ + int sg_num = item_ct1.get_local_range(2) >> 5; + + int iteration_stride = item_ct1.get_local_range(2); + int block_width = blockStride * seq_length; + + // cg::thread_block b = cg::this_thread_block(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int batch = item_ct1.get_group(2); + int row = item_ct1.get_group(1); + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = item_ct1.get_local_id(2) % max_threads_in_sequence; + + int data_offset = batch * (item_ct1.get_group_range(1) * block_width) + row * block_width + + (item_ct1.get_local_id(2) / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = item_ct1.get_local_id(2) >> 5; + int lane = item_ct1.get_local_id(2) & 0x1f; + + float2* val_cast = reinterpret_cast(vals); + const float2* attn_mask_cast = reinterpret_cast(attn_mask); + + val_cast += data_offset; + attn_mask_cast += mask_offset; + + float2 low_data[MAX_THREAD_ITERATIONS]; + float2 high_data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 data = val_cast[data_id]; + float2 mask = attn_mask_cast[data_id]; + + half2* data_arr = reinterpret_cast(&data); + half2* mask_arr = reinterpret_cast(&mask); + + low_data[i] = data_arr[0].convert(); + high_data[i] = data_arr[1].convert(); + float2 low_mask = mask_arr[0].convert(); + float2 high_mask = mask_arr[1].convert(); + + low_data[i].x() += low_mask.x(); + low_data[i].y() += low_mask.y(); + high_data[i].x() += high_mask.x(); + high_data[i].y() += high_mask.y(); + + max_val = (low_data[i].x() > max_val ? low_data[i].x() : max_val); + max_val = (low_data[i].y() > max_val ? low_data[i].y() : max_val); + max_val = (high_data[i].x() > max_val ? high_data[i].x() : max_val); + max_val = (high_data[i].y() > max_val ? high_data[i].y() : max_val); + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = sg.shuffle_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + item_ct1.barrier(); + + if (lane < sg_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + int iters = sg_num; + if (seq_length < iteration_stride) + iters = sg_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = sg.shuffle_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = sg.shuffle(max_val, item_ct1.get_local_id(2) / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + low_data[i] = sycl::exp(low_data[i] - max_val); + high_data[i] = sycl::exp(high_data[i] - max_val); + // low_data[i].x() = sycl::exp(low_data[i].x() - max_val); + // low_data[i].y() = sycl::exp(low_data[i].y() - max_val); + // high_data[i].x() = sycl::exp(high_data[i].x() - max_val); + // high_data[i].y() = sycl::exp(high_data[i].y() - max_val); + + sum += (low_data[i].x() + low_data[i].y() + high_data[i].x() + high_data[i].y()); + } + } + + for (int i = 1; i < tbSize; i *= 2) { sum += sg.shuffle_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + item_ct1.barrier(); + + if (lane < sg_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + item_ct1.barrier(); +#endif + + int iters = sg_num; + if (seq_length < iteration_stride) + iters = sg_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += sg.shuffle_xor(sum, i); } + + sum = sg.shuffle(sum, item_ct1.get_local_id(2) / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 result_f; + half2* result_h = reinterpret_cast(&result_f); + + low_data[i].x() /= sum; + low_data[i].y() /= sum; + high_data[i].x() /= sum; + high_data[i].y() /= sum; + + result_h[0] = low_data[i].convert(); + result_h[1] = high_data[i].convert(); + + val_cast[data_id] = result_f; + } + } +} + +template +void launch_attn_softmax(T* vals, + const T* attn_mask, + int batch_size, + int heads, + int sequence_length, + queue* stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + int block_compute_size = + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + range<3> grid_dim(1, heads * sequence_length / block_compute_size, batch_size); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + range<3> block_dim(1, + 1, + seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + stream->submit([&](handler& cgh) { + accessor data_block_acc_ct1( + range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<2, (threads / 2), 2>( + vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else if (sequence_length <= 16) + stream->submit([&](handler& cgh) { + accessor data_block_acc_ct1( + range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<4, (threads / 4), 4>( + vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else if (sequence_length <= 32) + stream->submit([&](handler& cgh) { + accessor data_block_acc_ct1( + range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<8, (threads / 8), 8>( + vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else if (sequence_length <= 64) + stream->submit([&](handler& cgh) { + accessor data_block_acc_ct1( + range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<16, (threads / 16), 16>( + vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else if (sequence_length <= 128) + stream->submit([&](handler& cgh) { + accessor data_block_acc_ct1( + range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<32, (threads / 32), 32>( + vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else if (sequence_length <= 256) + stream->submit([&](handler& cgh) { + accessor data_block_acc_ct1( + range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<32, (threads / 64), 64>( + vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else { + const int threads = 256; + block_compute_size = + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) + : 1); + range<3> grid_dim(1, heads * sequence_length / block_compute_size, batch_size); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + range<3> block_dim(1, + 1, + seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) { + stream->submit([&](handler& cgh) { + accessor + data_block_acc_ct1(range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for( + nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<32, (threads / 128), 128>(vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + } else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + stream->submit([&](handler& cgh) { + accessor + data_block_acc_ct1(range<1>(MAX_SG_NUM), cgh); + cgh.parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + attn_softmax<32, 1, 128>(vals, + attn_mask, + heads, + seq_length4, + iterations, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template void launch_attn_softmax(float* vals, + const float* attn_mask, + int batch_size, + int heads, + int sequence_length, + queue* stream); + +template void launch_attn_softmax(bf16* vals, + const bf16* attn_mask, + int batch_size, + int heads, + int sequence_length, + queue* stream); + +template void launch_attn_softmax(half* vals, + const half* attn_mask, + int batch_size, + int heads, + int sequence_length, + queue* stream); + +template +void softmax_backward_kernel(T* out_grad, + const T* soft_inp, + int seq_length, + nd_item<3> item_ct1, + float* partialSum) +{ + int sg_num = item_ct1.get_local_range().get(2) >> 5; // sg-count = num_threads / SG_SIZE (32) + + int iteration_stride = item_ct1.get_local_range().get(2); + int block_width = blockStride * seq_length; + + int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride) + ? (seq_length + iteration_stride - 1) / iteration_stride + : MAX_THREAD_ITERATIONS); + + // auto b = item_ct1.get_group(); + // cg::thread_block_tile g = cg::tiled_partition(b); + sub_group sg = item_ct1.get_sub_group(); + + int row = item_ct1.get_group(2); + int id = item_ct1.get_local_id(2); + + int wid = id >> 5; + int lane = id & 0x1f; + + T val_reg[MAX_THREAD_ITERATIONS]; + T soft_reg[MAX_THREAD_ITERATIONS]; + float grad_reg = 0.0f; + +#pragma unroll + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + val_reg[i] = out_grad[row * block_width + data_id]; + soft_reg[i] = soft_inp[row * block_width + data_id]; + + grad_reg += ((float)val_reg[i] * + (float)soft_reg[i]); // if done in half, the multiplication, we may + // lose 2% of accuracy in computation!! + } + } + for (int i = 1; i < tbSize; i *= 2) grad_reg += sg.shuffle_xor(grad_reg, i); + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = grad_reg; + item_ct1.barrier(); + + if (lane < sg_num) grad_reg = partialSum[lane]; + + int iters = sg_num; + if (seq_length < iteration_stride) iters = sg_num / (iteration_stride / seq_length); + + for (int i = 1; i < iters; i *= 2) grad_reg += sg.shuffle_xor(grad_reg, i); + + grad_reg = sg.shuffle(grad_reg, id / tbSize); + } + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg); + out_grad[row * block_width + data_id] = (T)temp; + } + } +} + +template +void softmax_backward_kernel_v2(T* grad /* input & output*/, + const T* output, + int softmax_length, + nd_item<3> item_ct1) +{ + int batch_idx = + item_ct1.get_group(2) * item_ct1.get_local_range().get(1) + item_ct1.get_local_id(1); + int offset = batch_idx * softmax_length + item_ct1.get_local_id(2); + + grad += offset; + output += offset; + + float sum = 0.0; + if constexpr (std::is_same_v) { + float grad_reg[ITERATIONS]; + float output_reg[ITERATIONS]; + ushort* grad_cast = (ushort*)grad; + const ushort* output_cast = (const ushort*)output; + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = item_ct1.get_local_id(2) + i * MAX_SG_NUM; + if (curr_idx < softmax_length) { + grad_reg[i] = bf16::to_float(grad_cast[i * MAX_SG_NUM]); + output_reg[i] = bf16::to_float(output_cast[i * MAX_SG_NUM]); + sum += grad_reg[i] * output_reg[i]; + } + } + + sub_group sg = item_ct1.get_sub_group(); + + for (int i = 1; i < MAX_SG_NUM; i <<= 1) sum += sg.shuffle_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = item_ct1.get_local_id(2) + i * MAX_SG_NUM; + if (curr_idx < softmax_length) { + grad_cast[i * MAX_SG_NUM] = bf16::from_float(output_reg[i] * (grad_reg[i] - sum)); + } + } + } else { + T grad_reg[ITERATIONS]; + T output_reg[ITERATIONS]; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = item_ct1.get_local_id(2) + i * MAX_SG_NUM; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * MAX_SG_NUM]; + output_reg[i] = output[i * MAX_SG_NUM]; + sum += (float)grad_reg[i] * (float)output_reg[i]; + } + } + sub_group sg = item_ct1.get_sub_group(); + + for (int i = 1; i < MAX_SG_NUM; i <<= 1) sum += sg.shuffle_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = item_ct1.get_local_id(2) + i * MAX_SG_NUM; + if (curr_idx < softmax_length) + grad[i * MAX_SG_NUM] = (float)output_reg[i] * ((float)grad_reg[i] - sum); + } + } +} + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + queue* stream) +{ + const int sgs_per_block = 4; + range<3> grid_dim(1, 1, batch_size * heads * seq_length / sgs_per_block); + range<3> block_dim(1, sgs_per_block, MAX_SG_NUM); + + if (seq_length <= 32) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 64) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 128) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 256) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 384) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 512) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 768) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 1024) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else if (seq_length <= 2048) + stream->parallel_for(nd_range<3>(grid_dim * block_dim, block_dim), + [=](nd_item<3> item_ct1) [[intel::reqd_sub_group_size(MAX_SG_NUM)]] { + softmax_backward_kernel_v2( + out_grad, soft_inp, seq_length, item_ct1); + }); + else + throw std::runtime_error( + std::string("Special sequence length found in softmax backward, seq_length: ") + + std::to_string(seq_length)); +} + +template void launch_attn_softmax_backward_v2(float* out_grad, + const float* soft_inp, + int batch_size, + int heads, + int seq_length, + queue* stream); +template void launch_attn_softmax_backward_v2(bf16* out_grad, + const bf16* soft_inp, + int batch_size, + int heads, + int seq_length, + queue* stream); +template void launch_attn_softmax_backward_v2(half* out_grad, + const half* soft_inp, + int batch_size, + int heads, + int seq_length, + queue* stream); diff --git a/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/transform_kernels.dp.cpp b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/transform_kernels.dp.cpp new file mode 100644 index 0000000..220b07c --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/csrc/transformer/sycl/transform_kernels.dp.cpp @@ -0,0 +1,891 @@ +#include +#include "sycl/custom_sycl_layers.hpp" + +#define rows_trans 16 +#define cols_trans 16 + +template +void Transpose_Kernel(const T* inp, + T* out, + int row_width, + int col_width, + sycl::nd_item<3> item_ct1, + T* data_block) +{ + int r = item_ct1.get_local_id(2) / cols_trans; + int c = item_ct1.get_local_id(2) % cols_trans; + + int m = row_width / cols_trans; + + int i = item_ct1.get_group(2) / m * rows_trans + r; + int j = item_ct1.get_group(2) % m * cols_trans + c; + + int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS); + + for (int k = 0; k < rows_trans; k += row_stride) + data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j]; + + item_ct1.barrier(); + + i = item_ct1.get_group(2) % m * rows_trans + r; + j = item_ct1.get_group(2) / m * cols_trans + c; + + for (int k = 0; k < rows_trans; k += row_stride) + out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k]; +} + +template <> +void Transpose(const sycl::half* inp_mat, + sycl::half* out_mat, + int rows, + int cols, + sycl::queue* stream) +{ + int threads = THREADS; + + sycl::range<3> grid_dim(1, 1, (rows * cols + threads - 1) / threads); + sycl::range<3> block_dim(1, 1, threads); + stream->submit([&](sycl::handler& cgh) { + sycl::accessor + data_block_acc_ct1(sycl::range<1>(rows_trans * (cols_trans + 1)), cgh); + cgh.parallel_for(sycl::nd_range<3>(grid_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { + Transpose_Kernel( + inp_mat, out_mat, cols, rows, item_ct1, data_block_acc_ct1.get_pointer()); + }); + }); +} + +template <> +void Transpose(const float* inp_mat, float* out_mat, int rows, int cols, sycl::queue* stream) +{ + int threads = THREADS; + sycl::range<3> grid_dim(1, 1, (rows * cols + threads - 1) / threads); + sycl::range<3> block_dim(1, 1, threads); + + stream->submit([&](sycl::handler& cgh) { + sycl::accessor + data_block_acc_ct1(sycl::range<1>(rows_trans * (cols_trans + 1)), cgh); + cgh.parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { + Transpose_Kernel( + inp_mat, out_mat, cols, rows, item_ct1, data_block_acc_ct1.get_pointer()); + }); + }); +} + +template +void transform_0213(T* output, + const T* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1); + +template <> +void transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1) / head_ext; // Sequence ID (0-127) + int d2 = item_ct1.get_local_id(1) + + (item_ct1.get_group(1) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + const sycl::float4* vals_vec = reinterpret_cast(vals); + sycl::float4* output_vec = reinterpret_cast(output); + + sycl::float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs; +} + +template <> +void transform_0213(bf16* output, + const bf16* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1) / head_ext; // Sequence ID (0-127) + int d2 = item_ct1.get_local_id(1) + + (item_ct1.get_group(1) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + const sycl::ushort4* vals_vec = reinterpret_cast(vals); + sycl::ushort4* output_vec = reinterpret_cast(output); + + sycl::ushort4 inputs_cast = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + float4 inputs = {bf16::to_float(inputs_cast.x()), + bf16::to_float(inputs_cast.y()), + bf16::to_float(inputs_cast.z()), + bf16::to_float(inputs_cast.w())}; + + sycl::float4 outputs; + outputs.x() = inputs.x(); + outputs.y() = inputs.y(); + outputs.z() = inputs.z(); + outputs.w() = inputs.w(); + + ushort4 outputs_cast = {bf16::from_float(outputs.x()), + bf16::from_float(outputs.y()), + bf16::from_float(outputs.z()), + bf16::from_float(outputs.w())}; + + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = outputs_cast; +} + +template <> +void transform_0213(sycl::half* output, + const sycl::half* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1) / head_ext; // Sequence ID (0-127) + int d2 = item_ct1.get_local_id(1) + + (item_ct1.get_group(1) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + sycl::float4 vals_arr[1]; + + const sycl::float4* vals_vec = reinterpret_cast(vals); + sycl::float4* output_vec = reinterpret_cast(output); + + vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0]; +} + +template <> +void launch_transform_0213(float* output, + const float* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim(1, (seq_length * head_ext), batch_size); + + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { + transform_0213( + output, vals, hidden_dim, seq_length, heads, head_ext, item_ct1); + }); + }); +} + +template <> +void launch_transform_0213(bf16* output, + const bf16* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim(1, (seq_length * head_ext), batch_size); + + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { + transform_0213( + output, vals, hidden_dim, seq_length, heads, head_ext, item_ct1); + }); + }); +} + +template <> +void launch_transform_0213(sycl::half* output, + const sycl::half* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream) +{ + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim(1, (seq_length * head_ext), batch_size); + + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { + transform_0213( + output, vals, hidden_dim, seq_length, heads, head_ext, item_ct1); + }); + }); +} + +// Bias add +template +void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1); + +template <> +void bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1); // Sequence ID (0-127) + int cnt = item_ct1.get_group(0) / head_ext; // Hidden count + int d2 = item_ct1.get_local_id(1) + + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + const sycl::float4* vals_vec = reinterpret_cast(vals); + const sycl::float4* bias_vec = reinterpret_cast(bias); + sycl::float4* output_vec = reinterpret_cast(output); + + sycl::float4 inputs = + vals_vec[d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext) + cnt * d1_stride + + d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext) + d2 * d2_stride + d3]; + sycl::float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + + sycl::float4 outputs; + outputs.x() = inputs.x() + biases.x(); + outputs.y() = inputs.y() + biases.y(); + outputs.z() = inputs.z() + biases.z(); + outputs.w() = inputs.w() + biases.w(); + + output_vec[cnt * d0_out_stride * item_ct1.get_group_range(2) + d0 * d0_out_stride + + d1 * d1_out_stride + d2 * d2_out_stride + d3] = outputs; +} + +template <> +void bias_add_transform_0213(bf16* output, + const bf16* vals, + const bf16* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1); // Sequence ID (0-127) + int cnt = item_ct1.get_group(0) / head_ext; // Hidden count + int d2 = item_ct1.get_local_id(1) + + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + const sycl::ushort4* vals_vec = reinterpret_cast(vals); + const sycl::ushort4* bias_vec = reinterpret_cast(bias); + sycl::ushort4* output_vec = reinterpret_cast(output); + + sycl::ushort4 inputs_cast = + vals_vec[d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext) + cnt * d1_stride + + d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext) + d2 * d2_stride + d3]; + sycl::ushort4 biases_cast = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + float4 inputs = {bf16::to_float(inputs_cast.x()), + bf16::to_float(inputs_cast.y()), + bf16::to_float(inputs_cast.z()), + bf16::to_float(inputs_cast.w())}; + + float4 biases = {bf16::to_float(biases_cast.x()), + bf16::to_float(biases_cast.y()), + bf16::to_float(biases_cast.z()), + bf16::to_float(biases_cast.w())}; + + sycl::float4 outputs; + outputs.x() = inputs.x() + biases.x(); + outputs.y() = inputs.y() + biases.y(); + outputs.z() = inputs.z() + biases.z(); + outputs.w() = inputs.w() + biases.w(); + + ushort4 outputs_cast = {bf16::from_float(outputs.x()), + bf16::from_float(outputs.y()), + bf16::from_float(outputs.z()), + bf16::from_float(outputs.w())}; + output_vec[cnt * d0_out_stride * item_ct1.get_group_range(2) + d0 * d0_out_stride + + d1 * d1_out_stride + d2 * d2_out_stride + d3] = outputs_cast; +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +template <> +void bias_add_transform_0213(sycl::half* output, + const sycl::half* vals, + const sycl::half* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1); // Sequence ID (0-127) + int cnt = item_ct1.get_group(0) / head_ext; // Hidden count + int d2 = item_ct1.get_local_id(1) + + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + sycl::float4 vals_arr; + sycl::float4 bias_arr; + sycl::float4 output_arr; + sycl::half2* vals_half = reinterpret_cast(&vals_arr); + sycl::half2* bias_half = reinterpret_cast(&bias_arr); + sycl::half2* output_half = reinterpret_cast(&output_arr); + + const sycl::float4* vals_vec = reinterpret_cast(vals); + const sycl::float4* bias_vec = reinterpret_cast(bias); + sycl::float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride * (item_ct1.get_group_range(0) / head_ext)); + vals_vec += (d1 * d1_stride * (item_ct1.get_group_range(0) / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + bias_vec += (cnt * d1_stride); + bias_vec += (d2 * d2_stride); + + output_vec += (cnt * d0_stride * item_ct1.get_group_range(2)); + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + bias_arr = bias_vec[d3]; + vals_arr = vals_vec[d3]; + +#if defined(__ACC_HALF__) + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; +#else + sycl::float2 bias_arr_f[4]; + sycl::float2 vals_arr_f[4]; +#pragma unroll + for (int l = 0; l < 4; l++) { + bias_arr_f[l] = bias_half[l].convert(); + vals_arr_f[l] = vals_half[l].convert(); + vals_arr_f[l].x() += bias_arr_f[l].x(); + vals_arr_f[l].y() += bias_arr_f[l].y(); + output_half[l] = vals_arr_f[l].convert(); + } +#endif + output_vec[d3] = output_arr; +} + +void bias_add_transform_0213_v2(sycl::half* output, + const sycl::half* vals, + const sycl::half* bias, + int hidden_dim, + int seq_length, + int heads, + sycl::nd_item<3> item_ct1, + sycl::float4* in_data) +{ + //__shared__ sycl::float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + int iteration_stride = d1_stride * item_ct1.get_local_range(0); // Hidden * 3 / 8 + int batch_stride = d0_stride * item_ct1.get_local_range(0); // Hidden * S * 3 / 8 + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1); // Sequence ID (0-127) + int cnt = item_ct1.get_local_id(0); // item_ct1.get_group(0); Hidden count + int d2 = item_ct1.get_local_id(1); // Head (0-11) + int d3 = item_ct1.get_local_id(2); // Values (groups of 4) + + sycl::float4 vals_arr[1]; + sycl::float4 bias_arr[1]; + sycl::float4 output_arr[1]; + sycl::half2* vals_half = reinterpret_cast(vals_arr); + sycl::half2* bias_half = reinterpret_cast(bias_arr); + sycl::half2* output_half = reinterpret_cast(output_arr); + + const sycl::float4* vals_vec = reinterpret_cast(vals); + const sycl::float4* bias_vec = reinterpret_cast(bias); + sycl::float4* output_vec = reinterpret_cast(output); + + int iter_index = cnt * d1_stride + d2 * d2_stride + d3; + int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); + bias_arr[0] = bias_vec[iter_index]; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + vals_arr[0] = vals_vec[input_offset + iter_id]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + + in_data[iter_id] = output_arr[0]; + } + item_ct1.barrier(); + + iteration_stride = item_ct1.get_local_range(0) * (item_ct1.get_local_range(1) >> 1); + int matrix_stride = (d0_out_stride * item_ct1.get_group_range(2)); + int head_count = (d2 >> 1) + cnt * (item_ct1.get_local_range(1) >> 1); + + int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = (iter * iteration_stride) + head_count; + int iter_offset = (iter_row % item_ct1.get_local_range(1)) * d2_out_stride + + (iter_row / item_ct1.get_local_range(1)) * matrix_stride; + output_vec[out_index + iter_offset] = + in_data[iter_row * d2_stride + d3 + + (d2 % 2) * (d1_stride * item_ct1.get_local_range(0))]; + } +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream, + int trans_count) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); + + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { + bias_add_transform_0213( + output, vals, bias, hidden_dim, seq_length, heads, head_ext, item_ct1); + }); + }); + // bias_add_transform_0213<<>>( + // output, vals, bias, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_bias_add_transform_0213(bf16* output, + const bf16* vals, + const bf16* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream, + int trans_count) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); + + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { + bias_add_transform_0213( + output, vals, bias, hidden_dim, seq_length, heads, head_ext, item_ct1); + }); + }); + // bias_add_transform_0213<<>>( + // output, vals, bias, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_bias_add_transform_0213(sycl::half* output, + const sycl::half* vals, + const sycl::half* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + sycl::queue* stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> block_dim(1, (heads / head_ext), hidden_dim / heads); + sycl::range<3> grid_dim((trans_count * head_ext), seq_length, batch_size); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dim * block_dim, block_dim), [=](sycl::nd_item<3> item_ct1) { + bias_add_transform_0213( + output, vals, bias, hidden_dim, seq_length, heads, head_ext, item_ct1); + }); + }); + // bias_add_transform_0213<<>>( + // output, vals, bias, hidden_dim, seq_length, heads, head_ext); + } else { + sycl::range<3> block_dim(trans_count, heads, hidden_dim / heads); + sycl::range<3> grid_dim(1, seq_length / 2, batch_size); + stream->submit([&](sycl::handler& cgh) { + sycl::accessor + data_block_acc_ct1(sycl::range<1>(3072), cgh); + cgh.parallel_for(sycl::nd_range<3>(grid_dim * block_dim, block_dim), + [=](sycl::nd_item<3> item_ct1) { + bias_add_transform_0213_v2(output, + vals, + bias, + hidden_dim, + seq_length, + heads, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + // bias_add_transform_0213_v2<<>>( + // output, vals, bias, hidden_dim, seq_length, heads); + } +} + +template +void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext, + sycl::nd_item<3> item_ct1); + +template <> +void transform4d_0213(float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1) / ((seq_length - 1) / item_ct1.get_local_range(1) + 1); // Head + int d2 = (item_ct1.get_local_id(1) + item_ct1.get_local_range(1) * item_ct1.get_group(1)) % + seq_length; + int cnt = item_ct1.get_group(0); + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) + + if (d2 < seq_length) { + const sycl::float4* in_vec = reinterpret_cast(in); + sycl::float4* out_vec = reinterpret_cast(out); + + sycl::float4 vals_vec = in_vec[cnt * d0_stride * item_ct1.get_group_range(2) + + d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + out_vec[d0 * d0_out_stride * item_ct1.get_group_range(0) + cnt * d2_out_stride + + d1 * d1_out_stride + d2 * d2_out_stride * item_ct1.get_group_range(0) + d3] = + vals_vec; + } +} + +template <> +void transform4d_0213(bf16* out, + const bf16* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_group(1) / ((seq_length - 1) / item_ct1.get_local_range(1) + 1); // Head + int d2 = (item_ct1.get_local_id(1) + item_ct1.get_local_range(1) * item_ct1.get_group(1)) % + seq_length; + int cnt = item_ct1.get_group(0); + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) + + if (d2 < seq_length) { + const sycl::ushort4* in_vec = reinterpret_cast(in); + sycl::ushort4* output_vec = reinterpret_cast(out); + + sycl::ushort4 vals_vec = in_vec[cnt * d0_stride * item_ct1.get_group_range(2) + + d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + + output_vec[d0 * d0_out_stride * item_ct1.get_group_range(0) + cnt * d2_out_stride + + d1 * d1_out_stride + d2 * d2_out_stride * item_ct1.get_group_range(0) + d3] = + vals_vec; + } +} + +template <> +void transform4d_0213(sycl::half* out, + const sycl::half* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext, + sycl::nd_item<3> item_ct1) +{ + int d0_stride = hidden_dim * (seq_length / head_ext); + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = + item_ct1.get_local_id(1) + (item_ct1.get_group(0) % head_ext) * (heads / head_ext); // Head + int d2 = item_ct1.get_group(0) / head_ext; // Sequence + int cnt = item_ct1.get_group(1); // Hidden count + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) + + const sycl::half4* in_vec = reinterpret_cast(in); + sycl::half4* out_vec = reinterpret_cast(out); + + in_vec += (cnt * d0_stride * item_ct1.get_group_range(2)); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); + + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * item_ct1.get_group_range(1)); + out_vec += (d2 * d1_stride * item_ct1.get_group_range(1)); + + out_vec[d3] = in_vec[d3]; +} + +void transform4d_0213_v2(sycl::half* out, + const sycl::half* in, + int heads, + int seq_length, + int hidden_dim, + sycl::nd_item<3> item_ct1, + sycl::float4* in_data) +{ + //__shared__ sycl::float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = item_ct1.get_group(2); // Batch + int d1 = item_ct1.get_local_id(1); // Head + int d2 = item_ct1.get_group(1); // Sequence + int cnt = item_ct1.get_local_id(0); // Hidden count + int d3 = item_ct1.get_local_id(2); // Values (groups of 8) + + const sycl::float4* in_vec = reinterpret_cast(in); + sycl::float4* out_vec = reinterpret_cast(out); + + int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int head_count = (d1 >> 1) + cnt * (item_ct1.get_local_range(1) >> 1); + int iteration_stride = item_ct1.get_local_range(0) * (item_ct1.get_local_range(1) >> 1); + int matrix_stride = (d0_stride * item_ct1.get_group_range(2)); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % item_ct1.get_local_range(1)) * d2_stride; + + in_data[d3 + iter_offset + + (iter_row / item_ct1.get_local_range(1) + (d1 % 2) * item_ct1.get_local_range(0)) * + d1_stride] = in_vec[input_offset + iter_offset * seq_length + + (iter_row / item_ct1.get_local_range(1)) * matrix_stride]; + } + item_ct1.barrier(); + + iteration_stride = d1_stride * item_ct1.get_local_range(0); + int iter_index = cnt * d1_stride + d1 * d2_stride + d3; + int output_offset = d0 * d0_stride * item_ct1.get_local_range(0) + d2 * (iteration_stride << 1); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; + } +} + +// 3 * [B A S N] - > [B S C*H] +template <> +void launch_transform4d_0213(float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + sycl::queue* stream, + int trans_count) +{ + hidden_dim >>= 2; + sycl::range<3> grid_dims(trans_count, heads * ((seq_length - 1) / 8 + 1), batch_size); + sycl::range<3> block_dims(1, 8, hidden_dim / heads); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + transform4d_0213(out, in, heads, seq_length, hidden_dim, 1, item_ct1); + }); + }); + // transform4d_0213 + // <<>>(out, in, heads, seq_length, + // hidden_dim, 1); +} + +template <> +void launch_transform4d_0213(bf16* out, + const bf16* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + sycl::queue* stream, + int trans_count) +{ + hidden_dim >>= 2; + sycl::range<3> grid_dims(trans_count, heads * ((seq_length - 1) / 8 + 1), batch_size); + sycl::range<3> block_dims(1, 8, hidden_dim / heads); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + transform4d_0213(out, in, heads, seq_length, hidden_dim, 1, item_ct1); + }); + }); + // transform4d_0213 + // <<>>(out, in, heads, seq_length, + // hidden_dim, 1); +} + +template <> +void launch_transform4d_0213(sycl::half* out, + const sycl::half* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + sycl::queue* stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + sycl::range<3> grid_dims((seq_length * head_ext), trans_count, batch_size); + sycl::range<3> block_dims(1, (heads / head_ext), hidden_dim / heads); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + transform4d_0213( + out, in, heads, seq_length, hidden_dim, head_ext, item_ct1); + }); + }); + } else { + sycl::range<3> grid_dims(1, seq_length / 2, batch_size); + sycl::range<3> block_dims(trans_count, heads, hidden_dim / heads); + stream->submit([&](sycl::handler& cgh) { + sycl::accessor + data_block_acc_ct1(sycl::range<1>(3072), cgh); + cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + transform4d_0213_v2(out, + in, + heads, + seq_length, + hidden_dim, + item_ct1, + data_block_acc_ct1.get_pointer()); + }); + }); + // transform4d_0213_v2<<>>( + // out, in, heads, seq_length, hidden_dim); + } +} diff --git a/intel_extension_for_deepspeed/op_builder/fused_adam.py b/intel_extension_for_deepspeed/op_builder/fused_adam.py new file mode 100644 index 0000000..3e2bd5b --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/fused_adam.py @@ -0,0 +1,31 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +from .builder import SYCLOpBuilder, sycl_kernel_path, sycl_kernel_include + + +class FusedAdamBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_FUSED_ADAM" + NAME = "fused_adam" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.adam.{self.NAME}_op' + + def sources(self): + return [ + sycl_kernel_path('csrc/adam/sycl/fused_adam_frontend.cpp'), + sycl_kernel_path('csrc/adam/sycl/multi_tensor_adam.dp.cpp'), + ] + + def include_paths(self): + return [ + sycl_kernel_include('csrc/includes'), + sycl_kernel_include('csrc/adam'), 'csrc/includes' + ] + + def cxx_args(self): + args = super().cxx_args() + return args + self.version_dependent_macros() diff --git a/intel_extension_for_deepspeed/op_builder/quantizer.py b/intel_extension_for_deepspeed/op_builder/quantizer.py new file mode 100644 index 0000000..b6d5ad5 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/quantizer.py @@ -0,0 +1,19 @@ +from .builder import SYCLOpBuilder, sycl_kernel_path, sycl_kernel_include + + +class QuantizerBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_QUANTIZER" + NAME = "quantizer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.quantizer.{self.NAME}_op' + + def sources(self): + return [] + + def include_paths(self): + return [] diff --git a/intel_extension_for_deepspeed/op_builder/transformer.py b/intel_extension_for_deepspeed/op_builder/transformer.py new file mode 100644 index 0000000..1150c93 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/transformer.py @@ -0,0 +1,47 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +from .builder import SYCLOpBuilder, sycl_kernel_path, sycl_kernel_include + + +class TransformerBuilder(SYCLOpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER" + NAME = "transformer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.transformer.{self.NAME}_op' + + def extra_ldflags(self): + return [] + + def sources(self): + return [ + sycl_kernel_path('csrc/transformer/sycl/onednn_wrappers.dp.cpp'), + sycl_kernel_path( + 'csrc/transformer/sycl/ds_transformer_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/onemkl_wrappers.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/transform_kernels.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/ds_gelu_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/gelu_kernels.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/ds_dropout_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/dropout_kernels.dp.cpp'), + sycl_kernel_path( + 'csrc/transformer/sycl/ds_feedforward_sycl.dp.cpp'), + sycl_kernel_path( + 'csrc/transformer/sycl/ds_layer_reorder_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/ds_normalize_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/normalize_kernels.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/ds_softmax_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/softmax_kernels.dp.cpp'), + sycl_kernel_path( + 'csrc/transformer/sycl/ds_stridedbatchgemm_sycl.dp.cpp'), + sycl_kernel_path('csrc/transformer/sycl/general_kernels.dp.cpp') + ] + + def include_paths(self): + includes = [sycl_kernel_include('csrc/includes'), 'csrc/includes'] + return includes diff --git a/intel_extension_for_deepspeed/op_builder/utils.py b/intel_extension_for_deepspeed/op_builder/utils.py new file mode 100644 index 0000000..81f9504 --- /dev/null +++ b/intel_extension_for_deepspeed/op_builder/utils.py @@ -0,0 +1,24 @@ +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" +from .builder import OpBuilder + + +class UtilsBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_UTILS" + NAME = "utils" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def sources(self): + return ['csrc/utils/flatten_unflatten.cpp'] + + def cxx_args(self): + return ['-O3', '-g', '-std=c++20', '-w', '-fPIC'] + + def extra_ldflags(self): + return ['-fPIC', '-Wl,-export-dynamic'] diff --git a/intel_extension_for_deepspeed/xpu_accelerator.py b/intel_extension_for_deepspeed/xpu_accelerator.py new file mode 100644 index 0000000..3aa2a70 --- /dev/null +++ b/intel_extension_for_deepspeed/xpu_accelerator.py @@ -0,0 +1,190 @@ +import torch +from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator +import intel_extension_for_pytorch as ipex # noqa: F401 +import oneccl_bindings_for_pytorch #noqa: F401 + + +class XPU_Accelerator(DeepSpeedAccelerator): + def __init__(self): + self._name = 'xpu' + self._communication_backend_name = 'ccl' + self.DoubleTensor = torch.xpu.DoubleTensor + self.LongTensor = torch.xpu.LongTensor + self.FloatTensor = torch.xpu.FloatTensor + self.BFloat16Tensor = torch.xpu.BFloat16Tensor + self.HalfTensor = torch.xpu.HalfTensor + self.IntTensor = torch.xpu.IntTensor + self.ByteTensor = torch.xpu.ByteTensor + + # Device APIs + def device_name(self, device_index=None): + if device_index == None: + return 'xpu' + return 'xpu:{}'.format(device_index) + + def device(self, device_index=None): + return torch.xpu.device(device_index) + + def set_device(self, device_index): + torch.xpu.set_device(device_index) + + def current_device(self): + return torch.xpu.current_device() + + def current_device_name(self): + return 'xpu:{}'.format(torch.xpu.current_device()) + + def device_count(self): + return torch.xpu.device_count() + + def synchronize(self, device_index=None): + return torch.xpu.synchronize(device_index) + + # RNG APIs + def random(self): + return torch.xpu.random + + def set_rng_state(self, new_state, device_index=None): + return torch.xpu.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + if device_index == None: + return torch.xpu.get_rng_state() + return torch.xpu.get_rng_state(device_index) + + def manual_seed(self, seed): + return torch.xpu.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.xpu.manual_seed_all(seed) + + def initial_seed(self, seed): + return torch.xpu.initial_seed(seed) + + def default_generator(self, device_index): + return torch.xpu.default_generators[device_index] + + # Streams/Events + def Stream(self, device=None, priority=0, **kwargs): + return torch.xpu.Stream(device, priority, **kwargs) + + def StreamContext(self, stream): + return torch.xpu.StreamContext(stream) + + def stream(self, stream): + return torch.xpu.stream(stream) + + def current_stream(self, device_index=None): + return torch.xpu.current_stream(device_index) + + def default_stream(self, device_index=None): + # torch.xpu does not support the sync behavior of default stream as cuda + # use current_stream as workaround + # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-streams + return torch.xpu.current_stream(device_index) + + def Event(self, **kwargs): + return torch.xpu.Event(**kwargs) + + # Memory management + def empty_cache(self): + return torch.xpu.empty_cache() + + def memory_allocated(self, device_index=None): + return torch.xpu.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + return torch.xpu.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + return torch.xpu.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + return torch.xpu.memory_reserved(device_index) + + def max_memory_cached(self, device_index=None): + return torch.xpu.max_memory_reserved(device_index) + + def reset_max_memory_cached(self, device_index=None): + return torch.xpu.reset_max_memory_reserved(device_index) + + def memory_stats(self, device_index=None): + return torch.xpu.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + return torch.xpu.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + return torch.xpu.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + return torch.xpu.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + return torch.xpu.get_device_properties(device_index).total_memory + + # Misc + def amp(self): + return torch.xpu.amp + + def is_available(self): + return torch.xpu.is_available() + + def range_push(self, msg): + return torch.xpu.itt.range_push(msg) + + def range_pop(self): + return torch.xpu.itt.range_pop() + + def lazy_call(self, callback): + return torch.xpu.lazy_init._lazy_call(callback) + + def communication_backend_name(self): + return self._communication_backend_name + + # Data types + def is_bf16_supported(self): + return True + + def is_fp16_supported(self): + return True + + # Tensor operations + def pin_memory(self, tensor): + return tensor.pin_memory(device=self.current_device_name()) + + def on_accelerator(self, tensor): + device_str = str(tensor.device) + if device_str.startswith('xpu:'): + return True + else: + return False + + def create_op_builder(self, op_name): + from intel_extension_for_deepspeed.op_builder import CPUAdagradBuilder, CPUAdamBuilder, FusedAdamBuilder, QuantizerBuilder, TransformerBuilder, UtilsBuilder + from deepspeed.ops.op_builder import AsyncIOBuilder, SparseAttnBuilder + if op_name == "AsyncIOBuilder": + return AsyncIOBuilder() + elif op_name == "CPUAdagradBuilder": + return CPUAdagradBuilder() + elif op_name == "CPUAdamBuilder": + return CPUAdamBuilder() + elif op_name == "FusedAdamBuilder": + return FusedAdamBuilder() + elif op_name == "QuantizerBuilder": + return QuantizerBuilder() + elif op_name == "SparseAttnBuilder": + return SparseAttnBuilder() + elif op_name == "TransformerBuilder": + return TransformerBuilder() + elif op_name == "UtilsBuilder": + return UtilsBuilder() + else: + return None + + def build_extension(self): + try: + from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension + except ImportError: + from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension + return DpcppBuildExtension diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..086bc03 --- /dev/null +++ b/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup +import subprocess + +version_str = "1.0" +git_branch_cmd = "git rev-parse --abbrev-ref HEAD" +git_hash_cmd = "git rev-parse --short HEAD" + + +def command_exists(cmd): + result = subprocess.Popen(f'type {cmd}', + stdout=subprocess.PIPE, + shell=True) + return result.wait() == 0 + + +if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" +else: + git_hash = "unknown" + git_branch = "unknown" + +print(f"version={version_str}, git_hash={git_hash}, git_branch={git_branch}") +version_str += f'+{git_hash}' + +setup(name="intel_extension_for_deepspeed", + version=version_str, + description="Intel Extension for DeepSpeed", + author="Intel Corporation", + include_package_data=True, + packages=["intel_extension_for_deepspeed"])