diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp
index 1a887b50e1a3..ddbc985e30ff 100644
--- a/csrc/fp_quantizer/fp_quantize.cpp
+++ b/csrc/fp_quantizer/fp_quantize.cpp
@@ -29,13 +29,13 @@ at::Tensor quantize(torch::Tensor& out,
                     int q_bits,
                     int q_mantisa_bits)
 {
-    int total_elems = at::numel(val);
+    size_t total_elems = at::numel(val);
     float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) :  // fp8 ranges
                         (q_bits == 12 ? 510.0 :                               // fp12 range
                              (q_bits == 6 ? 28.0 :                            // fp6 range
                                   6.0));  // fp4 range (using power 2); TODO (Reza): add the power-4
                                           // in case accuracy is not matching!
-    int num_groups = total_elems / group_size;
+    size_t num_groups = total_elems / group_size;
 
     DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
 #ifdef BF16_AVAILABLE
@@ -45,6 +45,18 @@ at::Tensor quantize(torch::Tensor& out,
     return out;
 }
 
+at::Tensor get_scales(torch::Tensor& out, int num_groups)
+{
+    auto options = at::TensorOptions()
+                       .dtype(torch::kFloat)
+                       .layout(at::kStrided)
+                       .device(at::kCUDA)
+                       .requires_grad(false);
+    auto scales =
+        torch::from_blob(out.data_ptr(), {num_groups, 1}, {out.stride(0) / 4, 1}, options);
+    return scales;
+}
+
 #define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa)              \
     if (val.options().dtype() == torch::T_TYPE) {                 \
         launch_dequantization<C_TYPE>((uint8_t*)val_q.data_ptr(), \
@@ -63,9 +75,9 @@ void dequantize(torch::Tensor& val,
                 int q_mantisa_bits,
                 int q_exponent_bits)
 {
-    int total_elems = at::numel(val);
+    size_t total_elems = at::numel(val);
 
-    int num_groups = total_elems / group_size;
+    size_t num_groups = total_elems / group_size;
 
     DISPATCH_DEQUANTIZE(kHalf, __half, 10);
 #ifdef BF16_AVAILABLE
@@ -93,9 +105,9 @@ void selective_dequantize(torch::Tensor& val,
                           int q_mantisa_bits,
                           int q_exponent_bits)
 {
-    int total_elems = at::numel(val);
+    size_t total_elems = at::numel(val);
     int num_indexes = indexes.size(0);
-    int num_groups = total_elems / group_size;
+    size_t num_groups = total_elems / group_size;
 
     DISPATCH_DEQUANTIZE_INDEX(kHalf, __half, 10);
 #ifdef BF16_AVAILABLE
diff --git a/csrc/fp_quantizer/fp_quantize.cu b/csrc/fp_quantizer/fp_quantize.cu
index 66ea7392e011..cb1f234fe94e 100644
--- a/csrc/fp_quantizer/fp_quantize.cu
+++ b/csrc/fp_quantizer/fp_quantize.cu
@@ -71,10 +71,10 @@ __global__ void apply_quantization(T* val,
                                    std::pair<uint64_t, uint64_t> seed,
                                    float q_range)
 {
-    int tidx = threadIdx.x;
-    int wid = tidx >> 5;
-    int lane = tidx & 0x1f;
-    int gid = blockIdx.x * quantization::warps + wid;
+    unsigned int tidx = threadIdx.x;
+    unsigned int wid = tidx >> 5;
+    unsigned int lane = tidx & 0x1f;
+    unsigned int gid = blockIdx.x * quantization::warps + wid;
 
     constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
     constexpr uint32_t _mantisa_mask = (1 << _mantisa_bits) - 1;
@@ -98,7 +98,7 @@ __global__ void apply_quantization(T* val,
     T cur_max;
     reduce::init<ROpType::Max>(&cur_max);
 
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;
 
     curandStatePhilox4_32_10_t state;
     curand_init(seed.first, idx, seed.second, &state);
@@ -228,7 +228,7 @@
 template <typename T>
 void launch_quantization(T* val,
                          uint8_t* q_val,
-                         int num_groups,
+                         size_t num_groups,
                          int group_size,
                          cudaStream_t stream,
                          float q_range,
@@ -344,12 +343,12 @@ void launch_quantization(T* val,
 {
     const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
     const dim3 block(quantization::threads);
-
     std::pair<uint64_t, uint64_t> seed = FPContext::Instance().IncrementOffset(16);
 
     constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
 
     const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
+
     QUANT_SWITCH((q_bits - q_mantisa_bits - 1) * q_mantisa_bits + stochastic_rounding, [&] {
         switch (copy_unroll) {
             LAUNCH_FOR_QUANTIZATION_UNROLL(1)
@@ -363,7 +362,7 @@ void launch_quantization(T* val,
 }
 #define INSTANTIATE_LAUNCH_QUANTIZATION(T, mantisa, exponent) \
     template void launch_quantization<T>(                     \
-        T*, uint8_t*, int, int, cudaStream_t, float q_range, int, int, int);
+        T*, uint8_t*, size_t, int, cudaStream_t, float q_range, int, int, int);
 // fp8(E4M3), nearest-rounding
 #ifdef BF16_AVAILABLE
 INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
@@ -373,7 +372,7 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
 template <typename T>
 void launch_dequantization(uint8_t* val,
                            T* q_val,
-                           int num_groups,
+                           size_t num_groups,
                            int group_size,
                            int q_mantisa_bits,
                            int q_exponent_bits,
@@ -390,7 +389,8 @@ void launch_dequantization(uint8_t* val,
     });
 }
 #define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
-    template void launch_dequantization<T>(uint8_t*, T*, int, int, int, int, cudaStream_t);
+    template void launch_dequantization<T>( \
+        uint8_t*, T*, size_t, int, int, int, cudaStream_t);
 // fp8(E4M3)
 #ifdef BF16_AVAILABLE
 INSTANTIATE_LAUNCH_DEQUANTIZATION(__nv_bfloat16, 7);
@@ -406,12 +406,12 @@ __global__ void apply_selective_dequantization(uint8_t* val,
                                                T* q_val,
                                                int32_t* indexes,
                                                int group_size,
-                                               int total_num_elements)
+                                               size_t total_num_elements)
 {
-    int index = indexes[blockIdx.x];
+    unsigned int index = indexes[blockIdx.x];
     constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
-    int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
-    int input_index = index * total_num_elements + tidx;
+    unsigned int tidx = (blockIdx.y * blockDim.x + threadIdx.x) * vector_size;
+    unsigned int input_index = index * total_num_elements + tidx;
     constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
     constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
     constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
@@ -504,17 +504,17 @@ template <typename T>
 void launch_selective_dequantization(uint8_t* val,
                                      T* q_val,
                                      int32_t* indexes,
-                                     int num_groups,
+                                     size_t num_groups,
                                      int group_size,
                                      int num_indexes,
                                      int q_mantisa_bits,
                                      int q_exponent_bits,
                                      cudaStream_t stream)
 {
-    int total_elements_per_index = (num_groups / num_indexes) * group_size;
-    int blocks = (total_elements_per_index - 1) /
-                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
-                 1;
+    size_t total_elements_per_index = (num_groups / num_indexes) * group_size;
+    size_t blocks = (total_elements_per_index - 1) /
+                        (quantization::threads * (quantization::access_granularity / sizeof(T))) +
+                    1;
     const dim3 grid(num_indexes, blocks);
     const dim3 block(quantization::threads);
     DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
@@ -524,7 +524,7 @@ void launch_selective_dequantization(uint8_t* val,
 }
 #define INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(T, mantisa) \
     template void launch_selective_dequantization<T>(           \
-        uint8_t*, T*, int32_t*, int, int, int, int, int, cudaStream_t);
+        uint8_t*, T*, int32_t*, size_t, int, int, int, int, cudaStream_t);
 // fp8(E4M3)
 #ifdef BF16_AVAILABLE
 INSTANTIATE_LAUNCH_SELECTIVE_DEQUANTIZATION(__nv_bfloat16, 7);
diff --git a/csrc/fp_quantizer/includes/fp_quantize.h b/csrc/fp_quantizer/includes/fp_quantize.h
index 60c75541f603..8b4c3af61b09 100644
--- a/csrc/fp_quantizer/includes/fp_quantize.h
+++ b/csrc/fp_quantizer/includes/fp_quantize.h
@@ -99,7 +99,7 @@
 template <typename T>
 void launch_quantization(T* val,
                          uint8_t* q_val,
-                         int num_groups,
+                         size_t num_groups,
                          int group_size,
                          cudaStream_t stream,
                          float q_range,
@@ -110,7 +110,7 @@ void launch_quantization(T* val,
 template <typename T>
 void launch_dequantization(uint8_t* val,
                            T* q_val,
-                           int num_groups,
+                           size_t num_groups,
                            int group_size,
                            int q_mantisa_bits,
                            int q_exponent_bits,
@@ -120,7 +120,7 @@ template <typename T>
 void launch_selective_dequantization(uint8_t* val,
                                      T* q_val,
                                      int32_t* indexes,
-                                     int num_groups,
+                                     size_t num_groups,
                                      int group_size,
                                      int num_indexes,
                                      int q_mantisa_bits,
diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py
index 70fabea845ba..a4d55b9027b3 100644
--- a/deepspeed/linear/quantization.py
+++ b/deepspeed/linear/quantization.py
@@ -62,7 +62,8 @@ def _ensure_quantized(self, tensor: torch.Tensor):
                 tensor.data = self.quantizer.quantize(tensor.data,
                                                       q_bits=self.quantization_config.q_bits,
                                                       q_mantisa_bits=self.quantization_config.mantissa_bits)
-            assert tensor.dtype == torch.uint8
+            assert (tensor.dtype == torch.int8), \
+                f"Quantize conversion dtype ({tensor.dtype}) error!"
 
     def dequantized(self) -> torch.Tensor:
         """
diff --git a/deepspeed/ops/__init__.py b/deepspeed/ops/__init__.py
index 15179984173c..574d7c6db381 100755
--- a/deepspeed/ops/__init__.py
+++ b/deepspeed/ops/__init__.py
@@ -13,3 +13,4 @@
 from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
 
 from ..git_version_info import compatible_ops as __compatible_ops__
+from . import fp_quantizer
\ No newline at end of file
diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py
index edd4ef57302c..3ce750cf7541 100644
--- a/deepspeed/ops/fp_quantizer/quantize.py
+++ b/deepspeed/ops/fp_quantizer/quantize.py
@@ -54,8 +54,9 @@ def quantize(self,
                  q_bits=8,
                  q_mantisa_bits=3,
                  stochastic_mode=False,
-                 return_meta_tensor=False) -> torch.Tensor:
-        assert input.dtype == torch.bfloat16, "only support bf16 for now"
+                 return_meta_tensor=False,
+                 out=None) -> torch.Tensor:
+        assert input.dtype == torch.bfloat16, f"only support bf16 for now, dtype is {input.dtype}"
 
         if return_meta_tensor:
             assert q_bits == 8, "meta tensor is only supported with q_bit=8"
@@ -73,23 +74,23 @@ def quantize(self,
         else:
             assert (0), \
                 f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
+
         self.num_groups = input.numel() // self.group_size
-        self.input_q = torch.ones(self.num_groups,
-                                  int(self.group_size * q_bits) // 8 + 4,
-                                  dtype=torch.uint8,
-                                  device=input.device)
-        out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)
+        self.input_q = torch.ones(
+            self.num_groups, int(self.group_size * q_bits) // 8 +
+            4, dtype=torch.uint8, device=input.device) if out is None else out
+        input_q_reshaped = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits,
+                                                    q_mantisa_bits)
         if return_meta_tensor:
-            data, self.scale = out.split(self.group_size, dim=-1)
-            data = data.contiguous().reshape(input.shape)
-            self.scale = self.scale.contiguous()
+            self.scales = input_q_reshaped[:, -4:].contiguous().reshape(-1, 4)
+            input_q_reshaped = self.input_q[:, :-4].contiguous().reshape(self.orig_shape)
             del self.input_q
-            del out
-            gc.collect()
-            get_accelerator().empty_cache()
-            return data, self.scale
+            self.input_q = None
+            return input_q_reshaped, self.scales
+        return input_q_reshaped
 
-        return out
+    def get_scales(self):
+        return fp_quant_module.get_scales(self.scales, self.num_groups)
 
     def to(self, *args, **kwargs):
         # Intermediate tensors may need to be moved to different devices
@@ -123,6 +124,7 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
             f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
             input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
         fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
+
         return fp_out
 
     def selective_dequantize(self,
@@ -151,11 +153,6 @@ def selective_dequantize(self,
             assert (0), \
                 f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"
 
-        if scale is not None:
-            assert input_q.numel() == fp_out.numel(), \
-            f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
-            input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
-
         fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
                                              q_bits - q_mantisa_bits - 1)
         return fp_out
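
For reference, a minimal usage sketch of the Python surface touched by this patch: FP_Quantize.quantize with return_meta_tensor and the new out= argument, plus the new get_scales hook. The FP_Quantize(group_size=...) constructor call, the CUDA device, and the bf16 input shape below are assumptions for illustration, not part of the patch.

import torch
from deepspeed.ops.fp_quantizer import FP_Quantize

group_size = 512
quantizer = FP_Quantize(group_size=group_size)  # assumed constructor signature; adjust to the installed version

# bf16 input whose element count is a multiple of group_size
x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# Quantize to fp8 (E4M3); return_meta_tensor=True returns the packed data and
# the per-group scale bytes separately, as in the patched quantize() above.
q, scales = quantizer.quantize(x, q_bits=8, q_mantisa_bits=3, return_meta_tensor=True)

# Optional: reuse a caller-provided buffer via the new out= argument.
# For q_bits=8 each group packs group_size data bytes plus 4 scale bytes.
prealloc = torch.empty(x.numel() // group_size, group_size + 4, dtype=torch.uint8, device="cuda")
q2, scales2 = quantizer.quantize(x, q_bits=8, q_mantisa_bits=3, return_meta_tensor=True, out=prealloc)

# get_scales() exposes the scale bytes as a float32 view through the new C++ binding.
fp_scales = quantizer.get_scales()

# Dequantize back to bf16, re-attaching the scale bytes that were split off above.
y = quantizer.dequantize(q, q_bits=8, q_mantisa_bits=3, scale=scales)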