[LPT] ConcatTransformation: support scalar equal DQ propagation through dynamic dimension #28350

Open · wants to merge 1 commit into master
87 changes: 67 additions & 20 deletions src/common/low_precision_transformations/src/concat.cpp
@@ -82,8 +82,12 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
         allDequantizationShiftConvertAreNotZero = false;
     }
 
+    const auto& concat_out_shape = concat->get_output_partial_shape(0);
+    const auto axis = ov::util::try_normalize_axis(concat->get_axis(), concat_out_shape.rank(), *concat);
+    const bool scalar_equal_constants_requested = concat_out_shape[axis].is_dynamic();
+
     // constant shape must be broadcastable to the shape on data.
-    auto broadcastElementWiseConst = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
+    auto adaptConstForConcatenation = [](std::shared_ptr<opset1::Constant> operation, const Shape targetShape) {
         auto targetShapeConst = std::make_shared<opset1::Constant>(element::i64, Shape{ targetShape.size() }, targetShape);
         auto broadcast = fold<ov::opset1::Broadcast>(operation, targetShapeConst);
         return broadcast;
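
For orientation: the whole change hinges on the new scalar_equal_constants_requested flag. When the concatenation dimension is dynamic, a per-channel dequantization constant cannot be broadcast to it, so only scalar (per-tensor) constants can be handled. A minimal standalone sketch of that check, using a simplified stand-in for ov::Dimension rather than the real OpenVINO types:

#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for ov::Dimension: a length is either known or dynamic (unknown
// until runtime); -1 encodes "dynamic" here, the real class is richer.
struct Dimension {
    long value = -1;
    bool is_dynamic() const { return value < 0; }
};

// Mirrors the check above: when the concat axis length is unknown, the pass
// must fall back to scalar dequantization constants.
bool scalar_constants_requested(const std::vector<Dimension>& out_shape, std::size_t axis) {
    assert(axis < out_shape.size());
    return out_shape[axis].is_dynamic();
}

int main() {
    const std::vector<Dimension> shape = {{1}, {-1}, {4}, {4}};  // {1, ?, 4, 4}
    assert(scalar_constants_requested(shape, 1));   // dynamic concat axis
    assert(!scalar_constants_requested(shape, 2));  // static dimension
    return 0;
}
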
@@ -100,8 +104,6 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
         [](const FakeQuantizeDequantization& value) { return !value.isLowPrecision(); });
 
     bool DqWithDifferentPrecision = someDqInLowPrecision && someDqInFpPrecision;
-    const auto axis =
-        ov::util::try_normalize_axis(concat->get_axis(), concat->get_output_partial_shape(0).rank(), *concat);
 
     OutputVector dataNodes;
     NodeVector convertNodes;
@@ -121,8 +123,13 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
             convertNodes.push_back(dequantization.convert);
         }
 
-        Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul);
-        targetShape[axis] = concat->get_input_partial_shape(i)[axis].get_length();
+        const auto targetShape = [&]() {
+            if (scalar_equal_constants_requested)
+                return ov::Shape{};
+            Shape targetShape(concat->get_input_partial_shape(i).rank().get_length(), 1ul);
+            targetShape[axis] = concat->get_input_partial_shape(i)[axis].get_length();
+            return targetShape;
+        }();
 
         if (!allDequantizationShiftAreZero) {
             auto subtractInput = dequantization.subtract == nullptr ?
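
The per-input target shape for the dequantization constants now depends on that flag. A standalone sketch of the selection logic (Shape reduced to a plain std::vector; dq_constant_target_shape is a hypothetical helper, not the OpenVINO API):

#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;

// Models the targetShape lambda above: a per-channel constant is reshaped to
// {1, ..., C_i, ..., 1} so the per-input constants can later be stacked along
// the concat axis; with a dynamic axis an empty (scalar) shape is used
// instead, since the axis length is unknown.
Shape dq_constant_target_shape(std::size_t rank,
                               std::size_t axis,
                               std::size_t axis_len,
                               bool scalar_requested) {
    if (scalar_requested)
        return Shape{};           // scalar constant: nothing to broadcast
    Shape target(rank, 1u);
    target[axis] = axis_len;      // e.g. rank 4, axis 1, len 3 -> {1, 3, 1, 1}
    return target;
}

int main() {
    const Shape per_channel = dq_constant_target_shape(4, 1, 3, false);  // {1, 3, 1, 1}
    const Shape scalar = dq_constant_target_shape(4, 1, 3, true);        // {}
    return (per_channel[1] == 3 && scalar.empty()) ? 0 : 1;
}
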
@@ -132,7 +139,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
                     deqPrecision),
                 targetShape,
                 std::vector<float>({ 0.f })) :
-                broadcastElementWiseConst(dequantization.subtractConstant, targetShape);
+                adaptConstForConcatenation(dequantization.subtractConstant, targetShape);
             if (allDequantizationShiftConvertAreNotZero) {
                 if (subtractConvert == nullptr && dequantization.subtractConvert != nullptr) {
                     subtractConvert = dequantization.subtractConvert;
@@ -147,7 +154,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
         if (!allDequantizationMultiplyAreZero) {
             mulConstants.push_back(dequantization.multiply == nullptr ?
                 std::make_shared<ov::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 1.0f })) :
-                broadcastElementWiseConst(dequantization.multiplyConstant, targetShape));
+                adaptConstForConcatenation(dequantization.multiplyConstant, targetShape));
         }
     }
 
@@ -162,10 +169,31 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
         lastDequantization = convert;
     }
 
+    auto concat_constants_if_needed = [&](const NodeVector& constants) {
+        OPENVINO_ASSERT(!constants.empty(), "concat_constants_if_needed expects a non-empty constants vector");
+        if (constants.size() == 1ul) {
+            return constants[0];
+        }
+        if (scalar_equal_constants_requested) {
+            if (ov::shape_size(constants[0]->get_output_shape(0)) == 1) {
+                const auto ref_value = ov::as_type_ptr<ov::op::v0::Constant>(constants[0])->cast_vector<float>();
+                bool all_constants_are_equal = true;
+                for (size_t i = 1ul; i < constants.size(); i++) {
+                    const auto cur_value = ov::as_type_ptr<ov::op::v0::Constant>(constants[i])->cast_vector<float>();
+                    if (ref_value != cur_value) {
+                        all_constants_are_equal = false;
+                    }
+                }
+                if (all_constants_are_equal)
+                    return constants[0];
+            }
+            OPENVINO_THROW("in case of a dynamic concatenation dimension, all constants must be scalar and equal");
+        }
+        return ov::pass::low_precision::fold<ov::opset1::Concat>(constants, axis);
+    };
+
     if (!subConstants.empty()) {
-        std::shared_ptr<ov::Node> subtractNode = subConstants.size() == 1ul ?
-            subConstants[0] :
-            ov::pass::low_precision::fold<ov::opset1::Concat>(subConstants, axis);
+        auto subtractNode = concat_constants_if_needed(subConstants);
         if (subtractConvert != nullptr)
             subtractNode = subtractConvert->clone_with_new_inputs({subtractNode});
         const auto subtract = std::make_shared<opset1::Subtract>(
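
The new concat_constants_if_needed lambda centralizes how the per-input DQ constants are merged. A standalone model of that policy, with constants reduced to flattened float vectors (merge_dq_constants is a hypothetical name; the real code works on ov::op::v0::Constant nodes and folds an opset1::Concat):

#include <cstddef>
#include <stdexcept>
#include <vector>

using ConstantValues = std::vector<float>;  // flattened constant payload

// Static concat axis: constants are concatenated along the axis.
// Dynamic concat axis: constants must all be scalar and equal, and the single
// shared scalar is reused as-is (no Concat is folded).
ConstantValues merge_dq_constants(const std::vector<ConstantValues>& constants,
                                  bool scalar_equal_constants_requested) {
    if (constants.size() == 1)
        return constants[0];
    if (scalar_equal_constants_requested) {
        if (constants[0].size() == 1) {
            bool all_equal = true;
            for (std::size_t i = 1; i < constants.size(); ++i)
                if (constants[i] != constants[0])
                    all_equal = false;
            if (all_equal)
                return constants[0];
        }
        throw std::runtime_error("dynamic concat dim: constants must be scalar and equal");
    }
    // Stand-in for fold<Concat> over the per-input constants.
    ConstantValues folded;
    for (const auto& c : constants)
        folded.insert(folded.end(), c.begin(), c.end());
    return folded;
}

int main() {
    // Two equal scalars pass through unchanged; mismatched ones would throw.
    const auto merged = merge_dq_constants({{0.1f}, {0.1f}}, /*dynamic axis*/ true);
    return merged.size() == 1 ? 0 : 1;
}
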
@@ -181,9 +209,7 @@ bool ConcatTransformation::transform(TransformationContext& context, ov::pass::p
         const auto multiply = std::make_shared<ov::op::TypeRelaxed<opset1::Multiply>>(
             opset1::Multiply(
                 lastDequantization,
-                NetworkHelper::toScalarIfPossible(mulConstants.size() == 1ul ?
-                    mulConstants[0] :
-                    ov::pass::low_precision::fold<ov::opset1::Concat>(mulConstants, axis))),
+                NetworkHelper::toScalarIfPossible(concat_constants_if_needed(mulConstants))),
             layerDequantizations[0].multiply->get_output_element_type(0));
 
         NetworkHelper::copyInfo({ concat, multiply }, multiply);
@@ -216,9 +242,32 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
         return false;
     }
 
+    auto base_dq_check = [&](const FakeQuantizeDequantization& dequantization) {
+        return !dequantization.empty() && (!updatePrecisions || dequantization.isLowPrecision());
+    };
+
     const size_t normalizedAxis = ov::util::try_normalize_axis(axis, outRank, *concat);
     if (outPShape[normalizedAxis].is_dynamic()) {
-        return false;
+        // In the case of a dynamic dimension, the dequantizations can be propagated only if they are all scalar and equal,
+        // since DQ broadcast is impossible (the requested shape is unknown) and only a single scalar DQ can be set after the Concat.
+        const auto dequantization_ref = NetworkHelper::getDequantization(concat, defaultPrecisions, 0);
+        if (!base_dq_check(dequantization_ref) || !dequantization_ref.isPerTensor())
+            return false;
+
+        auto extract_values = [](const std::shared_ptr<ov::op::v0::Constant>& constant) {
+            return constant ? constant->cast_vector<float>() : std::vector<float>();
+        };
+        const auto ref_shifts = extract_values(dequantization_ref.subtractConstant);
+        const auto ref_scales = extract_values(dequantization_ref.multiplyConstant);
+
+        for (size_t i = 1ul; i < concat->get_input_size(); i++) {
+            const auto cur_dequantization = NetworkHelper::getDequantization(concat, defaultPrecisions, i);
+            if (!base_dq_check(cur_dequantization) ||
+                ref_shifts != extract_values(cur_dequantization.subtractConstant) ||
+                ref_scales != extract_values(cur_dequantization.multiplyConstant))
+                return false;
+        }
+        return true;
     }
 
     auto checkConstShape = [&normalizedAxis, &outRank](const std::shared_ptr<opset1::Constant>& constant) {
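
canBeTransformed now mirrors the same restriction at matching time: with a dynamic concat dimension, input 0 acts as the reference and every other input must carry a per-tensor dequantization with identical shift and scale values. A standalone sketch of that gate, assuming the per-input values were already extracted (the real code pulls them via Constant::cast_vector<float>(); DqValues is a hypothetical struct):

#include <cstddef>
#include <vector>

// Per-input dequantization values as the gate sees them; an empty vector
// means the corresponding Subtract/Multiply is absent.
struct DqValues {
    std::vector<float> shifts;
    std::vector<float> scales;
    bool per_tensor = true;
};

bool can_propagate_through_dynamic_axis(const std::vector<DqValues>& inputs) {
    if (inputs.empty() || !inputs[0].per_tensor)
        return false;
    for (std::size_t i = 1; i < inputs.size(); ++i) {
        if (!inputs[i].per_tensor ||
            inputs[i].shifts != inputs[0].shifts ||
            inputs[i].scales != inputs[0].scales)
            return false;
    }
    return true;
}

int main() {
    // Matches the new test cases below: equal per-tensor DQs pass, different ones do not.
    const std::vector<DqValues> same = {{{128.f}, {0.1f}, true}, {{128.f}, {0.1f}, true}};
    const std::vector<DqValues> diff = {{{128.f}, {0.1f}, true}, {{128.f}, {10.f}, true}};
    return (can_propagate_through_dynamic_axis(same) && !can_propagate_through_dynamic_axis(diff)) ? 0 : 1;
}
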
@@ -235,7 +284,6 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
     };
 
     const auto check_const_precision = [](
-        const FakeQuantizeDequantization& dequantization,
         const std::shared_ptr<Node>& constant,
         ov::element::Type& const_precision) {
         if (constant == nullptr) {
@@ -253,9 +301,8 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
 
     for (size_t i = 0ul; i < concat->get_input_size(); i++) {
         const FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(concat, defaultPrecisions, i);
-        if (dequantization.empty() || (updatePrecisions && !dequantization.isLowPrecision())) {
+        if (!base_dq_check(dequantization))
             return false;
-        }
 
         if (((dequantization.subtract != nullptr) && (!checkConstShape(dequantization.subtractConstant))) ||
             ((dequantization.multiply != nullptr) && (!checkConstShape(dequantization.multiplyConstant)))) {
@@ -268,9 +315,9 @@ bool ConcatTransformation::canBeTransformed(const TransformationContext& context
             return false;
         }
 
-        if (!check_const_precision(dequantization, dequantization.subtractConvert, const_precision) ||
-            ((dequantization.subtractConvert == nullptr) && !check_const_precision(dequantization, dequantization.subtractConstant, const_precision)) ||
-            !check_const_precision(dequantization, dequantization.multiplyConstant, const_precision)) {
+        if (!check_const_precision(dequantization.subtractConvert, const_precision) ||
+            ((dequantization.subtractConvert == nullptr) && !check_const_precision(dequantization.subtractConstant, const_precision)) ||
+            !check_const_precision(dequantization.multiplyConstant, const_precision)) {
             return false;
         }
     }
@@ -128,6 +128,91 @@ const std::vector<ConcatTransformationTestValues> testValues = {
             {ov::element::f32, {128.f}, {0.1f}}
         }
     },
+    // dynamic concatenation axis, but the same per-tensor values
+    {
+        {{1, -1, 4, 4}, {1, -1, 4, 4}},
+        std::int64_t{1},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {128.f}, {0.1f}}
+            }
+        },
+        {
+            ov::element::u8,
+            {{}, {}},
+            ov::element::u8,
+            {ov::element::f32, {128.f}, {0.1f}}
+        }
+    },
+    // dynamic concatenation axis, DQs don't match
+    {
+        {{1, -1, 4, 4}, {1, -1, 4, 4}},
+        std::int64_t{1},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {}, {0.1f}}
+            }
+        },
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {}, {0.1f}}
+            },
+            ov::element::f32,
+            {}
+        }
+    },
+    // dynamic concatenation axis, different per-tensor values
+    {
+        {{1, -1, 4, 4}, {1, -1, 4, 4}},
+        std::int64_t{1},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {128.f}, {10.f}}
+            }
+        },
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {128.f}, {10.f}}
+            },
+            ov::element::f32,
+            {}
+        }
+    },
+    // dynamic output concatenation axis, but one input dim is static
+    {
+        {{1, -1, 4, 4}, {1, 3, 4, 4}},
+        std::int64_t{1},
+        LayerTransformation::createParamsU8I8(),
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {{128.f, 64.f, 128.f}}, {{10.f, 1.f, 10.f}}}
+            }
+        },
+        {
+            ov::element::u8,
+            {
+                {ov::element::f32, {128.f}, {0.1f}},
+                {ov::element::f32, {{128.f, 64.f, 128.f}}, {{10.f, 1.f, 10.f}}}
+            },
+            ov::element::f32,
+            {}
+        }
+    },
     {
         {{1, 3, 4, 4}, {1, 3, 4, 4}},
         std::int64_t{1},