Skip to content

Commit

Permalink
Merge branch 'master' into duli/capability
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Apr 24, 2024
2 parents 6ad67c7 + fbdf0ea commit 89b9f55
Show file tree
Hide file tree
Showing 16 changed files with 639 additions and 63 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-torch-latest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ jobs:
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.2"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.2"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.3"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.3"
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-latest-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -U --cache-dir $TORCH_CACHE torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cu118
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<b> <span style="color:orange" > DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)</span>.</b>

* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)]
* [2024/01] [DeepSpeed-FastGen: Introducting Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19)
* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/microsoft/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html)
* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp)
* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)]
Expand Down
77 changes: 56 additions & 21 deletions csrc/adam/multi_tensor_adam.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ typedef enum : int {

using MATH_T = float;

template <typename T>
template <typename T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
Expand All @@ -48,13 +48,13 @@ struct AdamFunctor {
// if(*noop_gmem == 1)
// return;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;

int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
Expand All @@ -71,7 +71,8 @@ struct AdamFunctor {
n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
for (index_t i_start = 0; i_start < n && i_start < chunk_size;
i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
Expand Down Expand Up @@ -146,23 +147,57 @@ void multi_tensor_adam_cuda(int chunk_size,
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) { break; }
}

// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
if (requires_64bit_indexing) {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE,
(int64_t)chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int64_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
} else {
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, int32_t>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
}

AT_CUDA_CHECK(cudaGetLastError());
}
10 changes: 5 additions & 5 deletions csrc/adam/multi_tensor_apply.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct TensorListMetadata {
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
__global__ void multi_tensor_apply_kernel(int64_t chunk_size,
volatile int* noop_flag,
T tl,
U callable,
Expand All @@ -46,8 +46,8 @@ __global__ void multi_tensor_apply_kernel(int chunk_size,
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
void multi_tensor_apply(int64_t block_size,
int64_t chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
Expand Down Expand Up @@ -91,9 +91,9 @@ void multi_tensor_apply(int block_size,
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;

int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
Expand Down
34 changes: 10 additions & 24 deletions deepspeed/checkpoint/ds_to_universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import argparse
import glob
import itertools
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import os
import re
import shutil
Expand Down Expand Up @@ -292,27 +292,18 @@ def get_matched_sub_params_pattern(name_):
return unmatched_patterns


def _get_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]


def _do_parallel_work(do_work, work_chunks, num_workers):
results = []
if num_workers > 1:
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
with ProcessPoolExecutor(max_workers=num_workers) as executor:
future_list = [executor.submit(do_work, work) for work in work_chunks]
for f in tqdm.tqdm(future_list):
results.append(f.result())
else:
# No parallel pass for unit testing
# We can't create child processes in tests
results = []
for batch in tqdm.tqdm(work_chunks):
res = [do_work(x) for x in batch]
results.extend(res)
for work in tqdm.tqdm(work_chunks):
results.append(do_work(work))
return results


Expand All @@ -321,20 +312,15 @@ def _extract_zero_shard_files(args, ds_checkpoint, temp_dir):
itertools.product(range(ds_checkpoint.pp_degree), range(ds_checkpoint.tp_degree),
range(ds_checkpoint.dp_degree)))
#pprint(f'{_3d_range_list=}')
work_chunks = list(_get_chunks(_3d_range_list, args.num_extract_workers))
#pprint(f'{work_chunks=}')

# extract_zero_shards(temp_dir, ds_checkpoint, _3d_range_list[0])
do_work = partial(extract_zero_shards, temp_dir, ds_checkpoint)
_do_parallel_work(do_work, work_chunks, args.num_extract_workers)
_do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)


def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
work_chunks = list(_get_chunks(list(slice_shapes.items()), args.num_merge_workers))
#pprint(work_chunks)
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_tp_slices, ds_checkpoint, zero_output_folder, temp_dir, ds_checkpoint.tp_degree)
unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_merge_workers)
unmatched_patterns_lists = _do_parallel_work(do_work, list(slice_shapes.items()), args.num_merge_workers)

# verify that all patterns were used
# if a pattern was not used by any of the workers, then it was not used at all -> assert/alert
Expand Down
7 changes: 7 additions & 0 deletions deepspeed/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .optimized_linear import OptimizedLinear
from .config import LoRAConfig, QuantizationConfig
39 changes: 39 additions & 0 deletions deepspeed/linear/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass


@dataclass
class LoRAConfig:
"""
Configuration settings for LoRAOptimizedLinear.
Attributes:
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
lora_alpha (float): LoRA scaling factor, default is 16.
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
reduction benefits. Defaults to 1, which means this feature is disabled.
"""
lora_r: int = 64
lora_alpha: float = 16.
base_weight_sharding: int = 1


@dataclass
class QuantizationConfig:
"""
Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear,
and QuantizedParameter
Attributes:
q_bits (int): The number of bits used for quantization. Default is 8.
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
group_size (int): The size of the group used for quantization. Default is 512.
"""
q_bits: int = 8
mantissa_bits: int = 3
group_size: int = 512
Loading

0 comments on commit 89b9f55

Please sign in to comment.