Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ONNX Rewriter and IR to simplify the mnb_to_qdq pass #1482

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions olive/passes/onnx/mnb_to_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np
import onnx
from onnxscript import ir
from onnxscript.rewriter import pattern as orp

from olive.hardware.accelerator import AcceleratorSpec
from olive.model import ONNXModelHandler
Expand Down Expand Up @@ -62,6 +64,100 @@
) -> ONNXModelHandler:
output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)

# Two steps:
# 1. pattern replacement
# 2. Repacking
ir_model = ir.load(model.model_path)

def mat_mul_n_bits_pattern(op, input_A, qweight, qscales, qzeros, g_idx, bias):
    """Describe the target pattern: one ``MatMulNBits`` node taking six inputs.

    Args:
        op: Pattern op builder supplied by the rewriter.
        input_A: Activation input of the node.
        qweight: Packed quantized weight.
        qscales: Quantization scales.
        qzeros: Quantization zero points.
        g_idx: Group index input.
        bias: Bias input.

    Returns:
        The pattern value produced by the ``MatMulNBits`` call.
    """
    # MatMulNBits is an ONNX Runtime contrib op registered under the
    # "com.microsoft" domain. Without ``_domain`` the pattern targets the
    # default ONNX domain and would never match a real MatMulNBits node.
    return op.MatMulNBits(input_A, qweight, qscales, qzeros, g_idx, bias, _domain="com.microsoft")

def _is_initializer(context, value: "ir.Value") -> bool:
    """Return whether *value* is registered as an initializer of the matched graph.

    Args:
        context: Match context passed to the check function; only its
            ``graph`` attribute is read.
        value: The IR value to test. ``None`` is treated as "not an initializer".

    Returns:
        ``True`` when the graph's initializer mapping holds this exact value
        object under ``value.name``; ``False`` otherwise.
    """
    if value is None:
        return False
    # The initializer mapping is keyed by value name, so a keyed lookup plus an
    # identity comparison avoids scanning every initializer on each check.
    return context.graph.initializers.get(value.name) is value

def mat_mul_n_bits_pattern_check(context, input_A, qweight, qscales, qzeros, g_idx, bias) -> bool:
    """Decide whether a matched ``MatMulNBits`` node may be rewritten.

    The match is accepted only when the packed weight is a graph initializer
    and the group index, when present, equals the trivial grouping
    ``arange(K) // block_size``.

    Args:
        context: Match context; forwarded to ``_is_initializer``.
        input_A: Activation input of the matched node.
        qweight: Packed quantized weight; must be an initializer.
        qscales: Quantization scales (not inspected here).
        qzeros: Quantization zero points (not inspected here).
        g_idx: Group index value, or ``None`` when the input is absent.
        bias: Bias value (not inspected here).

    Returns:
        ``True`` if the node can be replaced, ``False`` otherwise.
    """
    if not _is_initializer(context, qweight):
        return False
    if g_idx is None:
        # No explicit group index was supplied, so there is no custom
        # grouping to verify.
        return True

    # ``const_value`` is the attribute the rest of this file reads constant
    # tensors from; ``constant_value`` does not match that usage.
    g_idx_tensor = g_idx.const_value
    if g_idx_tensor is None:
        # g_idx is not a known constant, so it cannot be shown to be trivial.
        return False

    # NOTE(review): `_get_node` is not defined in the visible part of this
    # file; it is assumed to return the matched MatMulNBits node — confirm.
    node: ir.Node = _get_node(input_A)
    block_size = node.attributes["block_size"].value
    K = node.attributes["K"].value
    trivial_g_idx = np.arange(K, dtype=np.int32) // block_size
    # TODO: log when the match is rejected because g_idx is non-trivial.
    return bool(np.array_equal(g_idx_tensor.numpy(), trivial_g_idx))

def mat_mul_n_bits_replacement(op, input_A, qweight, qscales, qzeros, g_idx, bias):
    """Build the replacement subgraph for a matched ``MatMulNBits`` node.

    The emitted chain is DequantizeLinear -> Transpose (optional) -> MatMul ->
    Add (only when ``bias`` is given). The DequantizeLinear node is tagged in
    its ``meta`` with ``needs_repacking``, ``K`` and ``N`` so that the later
    repacking step can find it.

    Args:
        op: Op builder supplied by the rewriter.
        input_A: Activation input of the matched node.
        qweight: Packed quantized weight.
        qscales: Quantization scales.
        qzeros: Quantization zero points.
        g_idx: Group index value (unused here).
        bias: Bias value, or ``None`` when there is no bias.

    Returns:
        The output value of the final MatMul (or Add when a bias exists).
    """
    # NOTE(review): `_get_node` is not defined in the visible part of this
    # file; it is assumed to return the matched MatMulNBits node — confirm.
    node: ir.Node = _get_node(input_A)
    # TODO(justinchuby): Keep the old name of the node
    K: int = node.attributes["K"].value
    N: int = node.attributes["N"].value
    block_size: int = node.attributes["block_size"].value

    # will make this a per-axis DQ if there is a single block along K
    # - originally per-axis K == block_size
    # - originally blockwise but K <= block_size
    is_per_axis = math.ceil(K / block_size) == 1

    # Read the option once so the DequantizeLinear axis and the optional
    # Transpose node are always driven by the same config key.
    use_transpose_op = config["use_transpose_op"]

    dq_attrs = {
        # for some reason block_wise and per-axis appear to use swapped axis
        # flip the axis if it is per-axis
        "axis": (1 if use_transpose_op else 0) ^ (1 if is_per_axis else 0),
    }
    if not is_per_axis:
        # Only blockwise dequantization carries a block_size attribute; the
        # per-axis form omits it instead of passing None.
        dq_attrs["block_size"] = block_size

    # dequantizelinear -> transpose -> matmul -> add (optional)
    dq = op.DequantizeLinear(qweight, qscales, qzeros, **dq_attrs)

    # TODO(justinchuby): Improve the way we mark something that needs repacking
    dq_node = dq.producer()
    dq_node.meta["needs_repacking"] = True
    dq_node.meta["K"] = K
    dq_node.meta["N"] = N

    if use_transpose_op:
        dq = op.Transpose(dq, perm=[1, 0])
    matmul = op.MatMul(input_A, dq)
    if bias is not None:
        matmul = op.Add(matmul, bias)
    return matmul

# RewriteRule takes its arguments in the order (target pattern, replacement
# pattern, condition function), so the replacement builder must come before
# the check function.
replace_matmul_n_bits = orp.RewriteRule(
    mat_mul_n_bits_pattern,
    mat_mul_n_bits_replacement,
    mat_mul_n_bits_pattern_check,
)

# Call the rewriter with replace_matmul_n_bits

# 2. Repacking
# Visit every node in the graph and repack the constant inputs of those whose
# metadata carries the "needs_repacking" key.
for node in ir_model.graph:
if "needs_repacking" not in node.meta:
continue

# Add Logic handling input 3

# NOTE(review): `_unpack_weights` is not defined in the visible part of this
# file (the lint findings quoted below report the same) — it must be provided
# before this loop can run.
unpacked_weight_arrays = _unpack_weights(

Check failure

Code scanning / lintrunner

PYLINT/E0602 Error

Undefined variable '_unpack_weights' (undefined-variable)
See undefined-variable.

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name \_unpack\_weights.
See https://docs.astral.sh/ruff/rules/undefined-name
node.meta["K"],
node.meta["N"],
node.inputs[1].const_value.numpy(),
node.inputs[2].const_value.numpy(),
node.inputs[3].const_value.numpy(),
)
# NOTE(review): `ml_dtypes` does not appear in the visible import hunk —
# confirm it is imported elsewhere in the file.
array = unpacked_weight_arrays[0].view(ml_dtypes.int4)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
# NOTE(review): the same `array` (built from element 0 only) is written to
# inputs 1 and 2 and to the new input 3 below; presumably each input should
# receive its own unpacked array — confirm against `_unpack_weights`.
node.inputs[1].const_value = ir.Tensor(array)
node.inputs[2].const_value = ir.Tensor(array)
# NOTE(review): no name is assigned to `input_3` in the visible code before it
# is used as the key into `graph.initializers` — confirm a name is set.
input_3 = ir.Value(None)
input_3.const_value = ir.Tensor(array)
node.replace_input_with(3, input_3)
ir_model.graph.initializers[input_3.name] = input_3

# TODO(justinchuby): Register and remove initializers

ir_model.opset_imports[""] = max(21, ir_model.opset_imports[""])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Use a more robust version conversion process


# save the model to the output path and return the model
return ir_model_to_olive_model(ir_model, output_model_path, config)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

# create a dag from the model
dag = OnnxDAG.from_model_path(model.model_path)
# remove unnecessary identity nodes
Expand Down
Loading