From 997a43b743f492e1deab237a084748013b3d9e2a Mon Sep 17 00:00:00 2001
From: Yi Liu
Date: Tue, 7 Jan 2025 05:06:25 +0200
Subject: [PATCH] add pc

Signed-off-by: Yi Liu
---
 auto_round/autoround.py     | 13 ++--
 auto_round/data_type/fp8.py | 61 +++++++++++++++++++
 .../qlinear_triton_gptq.py  | 14 ++++-
 auto_round/quantizer.py     | 21 +++++--
 4 files changed, 99 insertions(+), 10 deletions(-)

diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index c8082aab..403f3ba3 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -203,7 +203,6 @@ def __init__(
         all_blocks = get_block_names(model)
         self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names)
         self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device
-        ##activation
         self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size
@@ -864,8 +863,13 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)
-        wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(
-            device)
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            device=device,
+            _inner_layer_name=layer_name,
+        ).to(device)
+
         round_params = []
         minmax_params = []
         for key in wrapper_linear.params.keys():
@@ -1212,6 +1216,8 @@ def quant_blocks(
         pbar = tqdm(range(0, len(block_names), nblocks))
         # for i in pbar:
         for i in range(len(block_names)):
+            if os.getenv("DEBUG_QUANT_BLOCK", "0") == "1" and i > 2:
+                break
             if nblocks == 1:
                 n = block_names[i]
                 pbar.set_description(f"Quantizing {n}")
@@ -1743,4 +1749,3 @@ def __init__(
             optimizer=optimizer,
             **kwargs,
         )
-
diff --git a/auto_round/data_type/fp8.py b/auto_round/data_type/fp8.py
index b236f8bc..c4f6e056 100644
--- a/auto_round/data_type/fp8.py
+++ b/auto_round/data_type/fp8.py
@@ -292,6 +292,67 @@ def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0
     return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
 
 
+@register_dtype("fp8_gaudi2_to_int_sym_pc")
+def progressive_quant_fp8_int4_per_channel(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
+                                           weight_fp8_max_scale=1.0, **kwargs):
+    """
+    Per-channel quantization.
+
+    Two-stage quantization: quantize the tensor to fp8 with a per-output-channel scale,
+    then quantize the fp8 values to int4 with grouping (w4g128).
+
+    This method first quantizes the input tensor into float8 format and then performs
+    a secondary quantization to int4 with grouping.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to quantize.
+        bits (int, optional): Bit precision for secondary quantization. Defaults to 4.
+        group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping).
+        v (float, optional): Optional parameter for variance tuning. Defaults to 0.
+        min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0.
+        max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0.
+        q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5.
+        weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0.
+        **kwargs: Additional arguments for compatibility.
+
+    Returns:
+        tuple:
+        - Quantized and dequantized tensor (torch.Tensor).
+        - Combined scaling factor (torch.Tensor).
+        - Placeholder for zp (None).
+    """
+    # tensor: [out_feats, in_feats]
+    # scale_bf16_to_fp8: [out_feats, 1]
+    out_feats, in_feats = tensor.shape
+    fp8_max = STANDARD_FP8E4M3FN_MAX * global_config.FP8_WEIGHT_BACKOFF
+    dim = 1
+    tensor_max = torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0].to(torch.float32) * weight_fp8_max_scale  ## better to train a ratio
+    scale = tensor_max.to(torch.float32) / fp8_max
+    min_scaling_factor = 1.0 / (fp8_max * 512.0)  ## copied from vllm
+    scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
+    fp8_res = tensor / scale_bf16_to_fp8
+    fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
+    float8_e4m3fn_ste_gaudi2 = get_gaudi2_fp8_ste_func()
+    fp8_res = float8_e4m3fn_ste_gaudi2(fp8_res)
+
+    ## convert to bf16
+    fp8_res_using_16bit = fp8_res.to(tensor.dtype)
+    ## convert to int4
+    from auto_round.data_type.int import quant_tensor_sym
+    qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym(fp8_res_using_16bit, bits=bits,
+                                                                          group_size=group_size, v=v,
+                                                                          min_scale=min_scale,
+                                                                          max_scale=max_scale,
+                                                                          scale_dtype=torch.bfloat16,
+                                                                          q_scale_thresh=q_scale_thresh)
+    qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8
+    scale_fp8_to_int4_with_group = scale_fp8_to_int4
+    scale_fp8_to_int4_with_group_reshape_back = scale_fp8_to_int4_with_group.reshape(out_feats, -1)
+    scale_bf16_to_int4 = scale_fp8_to_int4_with_group_reshape_back * scale_bf16_to_fp8
+    scale_bf16_to_int4_with_group = scale_bf16_to_int4.reshape(-1, 1)
+    return qdq_tensor, (scale_bf16_to_int4_with_group, scale_bf16_to_fp8), zp_fp8_to_int4
+    # return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
+
 @register_dtype("fp8_gaudi2_to_int_sym_v2")
 def progressive_quant_fp8_int4_v2(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
                                   weight_fp8_max_scale=1.0,**kwargs):
diff --git a/auto_round/export/export_to_autoround/qlinear_triton_gptq.py b/auto_round/export/export_to_autoround/qlinear_triton_gptq.py
index e8caf173..eb2a2c78 100644
--- a/auto_round/export/export_to_autoround/qlinear_triton_gptq.py
+++ b/auto_round/export/export_to_autoround/qlinear_triton_gptq.py
@@ -40,7 +40,7 @@
 import torch
 import torch.nn as nn
 import transformers
-
+import os
 # from auto_round_extension.cuda.triton_utils.mixin import TritonModuleMixin
 
 logger = getLogger(__name__)
@@ -117,10 +117,12 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
                 dtype=torch.bfloat16,
             ),
         )
+
+        _shape = (1, self.outfeatures) if os.environ.get("W4A8_PC", "0") == "1" else (1,)
         self.register_buffer(
             "w_bf16_to_fp8_scale",
             torch.zeros(
-                (1),
+                _shape,
                 dtype=torch.bfloat16,
             ),
         )
@@ -143,6 +145,14 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
 
         self.trainable = trainable
 
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}({self.infeatures}, {self.outfeatures}, "
+            f"bits={self.bits}, group_size={self.group_size}, "
+            f"scales shape: {self.scales.shape}, act_scales shape: {self.act_scales.shape}, w_bf16_to_fp8_scale shape: {self.w_bf16_to_fp8_scale.shape}"
+            f")"
+        )
+
     def post_init(self):
         pass
 
diff --git a/auto_round/quantizer.py b/auto_round/quantizer.py
index 447eb05b..4f5886e6 100644
--- a/auto_round/quantizer.py
+++ b/auto_round/quantizer.py
@@ -22,6 +22,7 @@
     set_module,
     logger
 )
+from loguru import logger as rich_logger
 
 
 def reshape_and_pad_tensor(v, group_size=-1):
@@ -58,7 +59,14 @@ class WrapperLinear(torch.nn.Module):
         device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
     """
 
-    def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu'):
+    def __init__(
+        self,
+        orig_layer,
+        enable_minmax_tuning=True,
+        enable_norm_bias_tuning=False,
+        device="cpu",
+        _inner_layer_name=None,
+    ):
         """Initializes the WrapperLinear module.
 
         Args:
@@ -68,6 +76,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tunin
             device (str): The computation device, such as 'cpu' or 'cuda'.
         """
         super(WrapperLinear, self).__init__()
+        self._inner_layer_name = _inner_layer_name
         self.orig_layer = orig_layer
         self.device = device
         self.enable_minmax_tuning = enable_minmax_tuning
@@ -152,7 +161,6 @@ def _qdq_weight(self, value, min_scale, max_scale):
             weight = self.orig_layer.get_weight().to(self.device)
         if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D):
             weight = weight.t()
-
         weight_q, scale, zp = self.weight_quant_func(weight, bits=self.orig_layer.bits,
                                                      group_size=self.orig_layer.group_size, v=value,
                                                      min_scale=min_scale, max_scale=max_scale,
@@ -468,8 +476,13 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device='
         if not check_to_quantized(m):
             unquantized_layers.append(n)
             continue
-        new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning,
-                              enable_norm_bias_tuning=enable_norm_bias_tuning, device=device)
+        new_m = WrapperLinear(
+            m,
+            enable_minmax_tuning=enable_minmax_tuning,
+            enable_norm_bias_tuning=enable_norm_bias_tuning,
+            device=device,
+            _inner_layer_name=n,
+        )
         set_module(block, n, new_m)
         quantized_layers.append(n)