Commit 997a43b

add pc
Signed-off-by: Yi Liu <[email protected]>
Yi4Liu committed Jan 7, 2025
1 parent 14a59a2 commit 997a43b
Showing 4 changed files with 99 additions and 10 deletions.
13 changes: 9 additions & 4 deletions auto_round/autoround.py
@@ -203,7 +203,6 @@ def __init__(
all_blocks = get_block_names(model)
self.quant_block_list = find_matching_blocks(model, all_blocks, self.to_quant_block_names)
self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device


##activation
self.act_group_size = act_group_size if not (act_group_size is None) else self.group_size
@@ -864,8 +863,13 @@ def quant_layer(self, layer_name, inputs, q_inputs=None, device=torch.device("cp
if q_inputs is not None:
q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(
device)
wrapper_linear = WrapperLinear(
layer,
enable_minmax_tuning=self.enable_minmax_tuning,
device=device,
_inner_layer_name=layer_name,
).to(device)

round_params = []
minmax_params = []
for key in wrapper_linear.params.keys():
@@ -1212,6 +1216,8 @@ def quant_blocks(
pbar = tqdm(range(0, len(block_names), nblocks))
# for i in pbar:
for i in range(len(block_names)):
if os.getenv("DEBUG_QUANT_BLOCK", "0") == "1" and i > 2:
break
if nblocks == 1:
n = block_names[i]
pbar.set_description(f"Quantizing {n}")
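
Note: the DEBUG_QUANT_BLOCK check added in this hunk stops the loop after the first three blocks when the variable is set, which is handy for quick debug runs. A minimal standalone sketch of the same gate (the block names here are made up for illustration):

import os

for i, name in enumerate(["model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"]):
    if os.getenv("DEBUG_QUANT_BLOCK", "0") == "1" and i > 2:
        break  # skip the remaining blocks once the first three are done
    print(f"quantizing {name}")

# Assumed usage: DEBUG_QUANT_BLOCK=1 python run_quantization.py
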
@@ -1743,4 +1749,3 @@ def __init__(
optimizer=optimizer,
**kwargs,
)

61 changes: 61 additions & 0 deletions auto_round/data_type/fp8.py
@@ -292,6 +292,67 @@ def progressive_quant_fp8_int4(tensor, bits=4, group_size=-1, v=0, min_scale=1.0

return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4


@register_dtype("fp8_gaudi2_to_int_sym_pc")
def progressive_quant_fp8_int4_per_channel(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
weight_fp8_max_scale=1.0,**kwargs):
"""
Per-channel quantization.
Two-stage quantization: quantize the tensor to fp8 with a per-output-channel scale, then quantize fp8 to int4 with grouping (w4g128).
This method first quantizes the input tensor into float8 format and then performs
a secondary quantization to int4 with grouping.
Args:
tensor (torch.Tensor): Input tensor to quantize.
bits (int, optional): Bit precision for secondary quantization. Defaults to 4.
group_size (int, optional): Group size for int4 quantization. Defaults to -1 (no grouping).
v (float, optional): Optional parameter for variance tuning. Defaults to 0.
min_scale (float, optional): Minimum scaling factor for int4 quantization. Defaults to 1.0.
max_scale (float, optional): Maximum scaling factor for int4 quantization. Defaults to 1.0.
q_scale_thresh (float, optional): Threshold for scaling. Defaults to 1e-5.
weight_fp8_max_scale (float, optional): Maximum scaling factor for float8 quantization. Defaults to 1.0.
**kwargs: Additional arguments for compatibility.
Returns:
tuple:
- Quantized and dequantized tensor (torch.Tensor).
- Combined scaling factor (torch.Tensor).
- Placeholder for zp (None).
"""
# tensor: [out_feats, in_feats]
# scale_bf16_to_fp8: [out_feats, 1]
out_feats, in_feats = tensor.shape
fp8_max = STANDARD_FP8E4M3FN_MAX * global_config.FP8_WEIGHT_BACKOFF
dim = 1
tensor_max = torch.max(torch.abs(tensor), dim=dim, keepdim=True)[0].to(torch.float32) * weight_fp8_max_scale ## better train a ratio
scale = tensor_max.to(torch.float32) / fp8_max
min_scaling_factor = 1.0 / (fp8_max* 512.0) ##copy from vllm
scale_bf16_to_fp8 = torch.clip(scale, min=min_scaling_factor)
fp8_res = tensor / scale_bf16_to_fp8
fp8_res = torch.clip(fp8_res, -fp8_max, fp8_max)
float8_e4m3fn_ste_gaudi2 = get_gaudi2_fp8_ste_func()
fp8_res = float8_e4m3fn_ste_gaudi2(fp8_res)

##convert to bf16
fp8_res_using_16bit = fp8_res.to(tensor.dtype)
##convert to int4
from auto_round.data_type.int import quant_tensor_sym
qdq_int4_tensor, scale_fp8_to_int4, zp_fp8_to_int4 = quant_tensor_sym(fp8_res_using_16bit, bits=bits,
group_size=group_size, v=v,
min_scale=min_scale,
max_scale=max_scale,
scale_dtype=torch.bfloat16,
q_scale_thresh=q_scale_thresh)
qdq_tensor = qdq_int4_tensor * scale_bf16_to_fp8
scale_fp8_to_int4_with_group = scale_fp8_to_int4
scale_fp8_to_int4_with_group_reshape_back = scale_fp8_to_int4_with_group.reshape(out_feats, -1)
scale_bf16_to_int4 = scale_fp8_to_int4_with_group_reshape_back * scale_bf16_to_fp8
scale_bf16_to_int4_with_group = scale_bf16_to_int4.reshape(-1, 1)
return qdq_tensor, (scale_bf16_to_int4_with_group, scale_bf16_to_fp8), zp_fp8_to_int4
# return qdq_tensor, (scale_fp8_to_int4 * scale_bf16_to_fp8, scale_bf16_to_fp8), zp_fp8_to_int4
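
For reference, below is a minimal, self-contained sketch of the same two-stage idea using plain torch ops: a per-output-channel BF16 -> FP8 scale followed by grouped symmetric INT4 quantization, with the two scales folded together at the end. It assumes in_feats is divisible by group_size and does not reproduce the Gaudi2 FP8 STE or the exact rounding policy of quant_tensor_sym.

import torch

def two_stage_quant_sketch(w, group_size=128, fp8_max=448.0, backoff=1.0):
    # w: [out_feats, in_feats]; fp8_max defaults to the e4m3fn maximum
    out_feats, in_feats = w.shape
    # Stage 1: one BF16 -> FP8 scale per output channel
    ch_max = w.abs().amax(dim=1, keepdim=True).float()                     # [out_feats, 1]
    scale_bf16_to_fp8 = (ch_max * backoff / fp8_max).clamp_min(1.0 / (fp8_max * 512.0))
    fp8_range = (w / scale_bf16_to_fp8).clamp(-fp8_max, fp8_max)           # fp8-range values kept in float
    # Stage 2: grouped symmetric INT4 on the fp8-range tensor
    g = fp8_range.reshape(-1, group_size)
    scale_fp8_to_int4 = (g.abs().amax(dim=1, keepdim=True) / 7.0).clamp_min(1e-5)
    q = torch.clamp(torch.round(g / scale_fp8_to_int4), -8, 7)
    qdq_int4 = (q * scale_fp8_to_int4).reshape(out_feats, in_feats)
    # Fold per-group and per-channel scales so INT4 codes map straight back to BF16
    scale_bf16_to_int4 = scale_fp8_to_int4.reshape(out_feats, -1) * scale_bf16_to_fp8
    return qdq_int4 * scale_bf16_to_fp8, scale_bf16_to_int4, scale_bf16_to_fp8

# Example: w = torch.randn(256, 512); qdq, s_w4, s_fp8 = two_stage_quant_sketch(w)
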

@register_dtype("fp8_gaudi2_to_int_sym_v2")
def progressive_quant_fp8_int4_v2(tensor, bits=4, group_size=-1, v=0, min_scale=1.0, max_scale=1.0, q_scale_thresh=1e-5,
weight_fp8_max_scale=1.0,**kwargs):
14 changes: 12 additions & 2 deletions auto_round/export/export_to_autoround/qlinear_triton_gptq.py
@@ -40,7 +40,7 @@
import torch
import torch.nn as nn
import transformers

import os
# from auto_round_extension.cuda.triton_utils.mixin import TritonModuleMixin

logger = getLogger(__name__)
@@ -117,10 +117,12 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa
dtype=torch.bfloat16,
),
)

_shape = (1, self.outfeatures) if os.environ.get("W4A8_PC", "0") == "1" else (1)
self.register_buffer(
"w_bf16_to_fp8_scale",
torch.zeros(
(1),
_shape,
dtype=torch.bfloat16,
),
)
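
A small, standalone illustration (not the real module, names reused only for clarity) of what the W4A8_PC switch above changes: the w_bf16_to_fp8_scale buffer goes from a single scalar to one scale per output channel, and both shapes broadcast against a [batch, outfeatures] tensor.

import os
import torch

outfeatures = 4096
per_channel = os.environ.get("W4A8_PC", "0") == "1"
shape = (1, outfeatures) if per_channel else (1,)
w_bf16_to_fp8_scale = torch.ones(shape, dtype=torch.bfloat16)

y = torch.randn(8, outfeatures, dtype=torch.bfloat16)
y_rescaled = y * w_bf16_to_fp8_scale  # broadcasts for both the scalar and per-channel shapes
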
@@ -143,6 +145,14 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=Fa

self.trainable = trainable

def __repr__(self):
return (
f"{self.__class__.__name__}({self.infeatures}, {self.outfeatures}, "
f"bits={self.bits}, group_size={self.group_size},"
f"scales shape: {self.scales.shape}, act_scales shape: {self.act_scales.shape}, w_bf16_to_fp8_scale shape: {self.w_bf16_to_fp8_scale.shape}"
f")"
)

def post_init(self):
pass

21 changes: 17 additions & 4 deletions auto_round/quantizer.py
@@ -22,6 +22,7 @@
set_module,
logger
)
from loguru import logger as rich_logger


def reshape_and_pad_tensor(v, group_size=-1):
@@ -58,7 +59,14 @@ class WrapperLinear(torch.nn.Module):
device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
"""

def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tuning=False, device='cpu'):
def __init__(
self,
orig_layer,
enable_minmax_tuning=True,
enable_norm_bias_tuning=False,
device="cpu",
_inner_layer_name=None,
):
"""Initializes the WrapperLinear module.
Args:
@@ -68,6 +76,7 @@ def __init__(self, orig_layer, enable_minmax_tuning=True, enable_norm_bias_tunin
device (str): The computation device, such as 'cpu' or 'cuda'.
"""
super(WrapperLinear, self).__init__()
self._inner_layer_name = _inner_layer_name
self.orig_layer = orig_layer
self.device = device
self.enable_minmax_tuning = enable_minmax_tuning
@@ -152,7 +161,6 @@ def _qdq_weight(self, value, min_scale, max_scale):
weight = self.orig_layer.get_weight().to(self.device)
if isinstance(self.orig_layer, transformers.modeling_utils.Conv1D):
weight = weight.t()

weight_q, scale, zp = self.weight_quant_func(weight, bits=self.orig_layer.bits,
group_size=self.orig_layer.group_size, v=value,
min_scale=min_scale, max_scale=max_scale,
@@ -468,8 +476,13 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device='
if not check_to_quantized(m):
unquantized_layers.append(n)
continue
new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning,
enable_norm_bias_tuning=enable_norm_bias_tuning, device=device)
new_m = WrapperLinear(
m,
enable_minmax_tuning=enable_minmax_tuning,
enable_norm_bias_tuning=enable_norm_bias_tuning,
device=device,
_inner_layer_name=n,
)
set_module(block, n, new_m)
quantized_layers.append(n)
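
A hypothetical mini-example of the pattern used above: walk a block's submodules, wrap each quantizable Linear, and thread the submodule name through as _inner_layer_name so the wrapper can identify which layer it holds (for example in logs or __repr__). NamedWrapper here is a stand-in, not the real WrapperLinear.

import torch

class NamedWrapper(torch.nn.Module):
    def __init__(self, orig_layer, _inner_layer_name=None):
        super().__init__()
        self.orig_layer = orig_layer
        self._inner_layer_name = _inner_layer_name

    def forward(self, x):
        return self.orig_layer(x)

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16))
for n, m in list(block.named_modules()):
    if isinstance(m, torch.nn.Linear):
        setattr(block, n, NamedWrapper(m, _inner_layer_name=n))  # replaces submodules "0" and "2"
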

