Skip to content

Commit

Permalink
Fix summary row issues in case postaggregations are happening (apache…
Browse files Browse the repository at this point in the history
…#15232)

* fix-1/2

* add message v1

* extend test to cover for IOB issue

* move stuff around

* change message

* fix testcase string

* compute postaggs (thank you Clint!)

* enable feature for test

* ignore tests in msq

---------

Co-authored-by: Soumyava Das <[email protected]>
  • Loading branch information
kgyrtkirk and soumyava authored Oct 25, 2023
1 parent 06f40a0 commit 6784e9c
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@ public void testQueryWithMoreThanMaxNumericInFilter()

}

@Ignore
@Override
public void testUnSupportedNullsFirst()
{
}

@Ignore
@Override
public void testUnSupportedNullsLast()
{
}

/**
* Same query as {@link CalciteQueryTest#testArrayAggQueryOnComplexDatatypes}. ARRAY_AGG is not supported in MSQ currently.
* Once support is added, this test can be removed and msqCompatible() can be added to the one in CalciteQueryTest.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,11 +788,19 @@ private static boolean summaryRowPreconditions(GroupByQuery query)
private static Iterator<ResultRow> summaryRowIterator(GroupByQuery q)
{
List<AggregatorFactory> aggSpec = q.getAggregatorSpecs();
Object[] values = new Object[aggSpec.size()];
ResultRow resultRow = ResultRow.create(q.getResultRowSizeWithPostAggregators());
for (int i = 0; i < aggSpec.size(); i++) {
values[i] = aggSpec.get(i).factorize(new AllNullColumnSelectorFactory()).get();
resultRow.set(i, aggSpec.get(i).factorize(new AllNullColumnSelectorFactory()).get());
}
return Collections.singleton(ResultRow.of(values)).iterator();
Map<String, Object> map = resultRow.toMap(q);
for (int i = 0; i < q.getPostAggregatorSpecs().size(); i++) {
final PostAggregator postAggregator = q.getPostAggregatorSpecs().get(i);
final Object value = postAggregator.compute(map);

resultRow.set(q.getResultRowPostAggregatorStart() + i, value);
map.put(postAggregator.getName(), value);
}
return Collections.singleton(resultRow).iterator();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12962,6 +12962,9 @@ public void testSummaryrowForEmptyInput()
new FloatSumAggregatorFactory("idxFloat", "indexFloat"),
new DoubleSumAggregatorFactory("idxDouble", "index")
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ExpressionPostAggregator("post", "idx * 2", null, TestExprMacroTable.INSTANCE)))
.setGranularity(QueryRunnerTestHelper.ALL_GRAN)
.build();

Expand All @@ -12976,7 +12979,9 @@ public void testSummaryrowForEmptyInput()
"idxFloat",
NullHandling.replaceWithDefault() ? 0.0 : null,
"idxDouble",
NullHandling.replaceWithDefault() ? 0.0 : null
NullHandling.replaceWithDefault() ? 0.0 : null,
"post",
NullHandling.replaceWithDefault() ? 0L : null
)
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ private SqlValidator createSqlValidator(CalciteCatalogReader catalogReader)
opTab,
catalogReader,
getTypeFactory(),
validatorConfig
validatorConfig,
context.unwrapOrThrow(PlannerContext.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,41 @@
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.sql.calcite.run.EngineFeature;

/**
* Druid extended SQL validator. (At present, it doesn't actually
* have any extensions yet, but it will soon.)
*/
class DruidSqlValidator extends BaseDruidSqlValidator
{
private final PlannerContext plannerContext;

protected DruidSqlValidator(
SqlOperatorTable opTab,
CalciteCatalogReader catalogReader,
JavaTypeFactory typeFactory,
Config validatorConfig
Config validatorConfig,
PlannerContext plannerContext
)
{
super(opTab, catalogReader, typeFactory, validatorConfig);
this.plannerContext = plannerContext;
}

@Override
public void validateCall(SqlCall call, SqlValidatorScope scope)
{
if (call.getKind() == SqlKind.OVER) {
if (!plannerContext.featureAvailable(EngineFeature.WINDOW_FUNCTIONS)) {
throw buildCalciteContextException(
StringUtils.format(
"The query contains window functions; To run these window functions, enable [%s] in query context.",
EngineFeature.WINDOW_FUNCTIONS),
call);
}
}
if (call.getKind() == SqlKind.NULLS_FIRST) {
SqlNode op0 = call.getOperandList().get(0);
if (op0.getKind() == SqlKind.DESCENDING) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,12 @@ public SqlConformance conformance()
return DruidConformance.instance();
}
};
} else {
return null;
}
if (aClass.equals(PlannerContext.class)) {
return (C) plannerContext;
}

return null;
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14293,6 +14293,17 @@ public void testUnSupportedNullsLast()
assertThat(e, invalidSqlIs("ASCENDING ordering with NULLS LAST is not supported! (line [1], column [41])"));
}

@Test
public void testWindowingErrorWithoutFeatureFlag()
{
DruidException e = assertThrows(DruidException.class, () -> testBuilder()
.queryContext(ImmutableMap.of(PlannerContext.CTX_ENABLE_WINDOW_FNS, false))
.sql("SELECT dim1,ROW_NUMBER() OVER () from druid.foo")
.run());

assertThat(e, invalidSqlIs("The query contains window functions; To run these window functions, enable [WINDOW_FUNCTIONS] in query context. (line [1], column [13])"));
}

@Test
public void testInGroupByLimitOutGroupByOrderBy()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
package org.apache.druid.sql.calcite;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.sql.calcite.NotYetSupported.Modes;
import org.apache.druid.sql.calcite.NotYetSupported.NotYetSupportedProcessor;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.junit.Rule;
import org.junit.Test;

Expand Down Expand Up @@ -53,6 +55,7 @@ public void testTasksSumOver()
msqIncompatible();

testBuilder()
.queryContext(ImmutableMap.of(PlannerContext.CTX_ENABLE_WINDOW_FNS, true))
.sql("select datasource, sum(duration) over () from sys.tasks group by datasource")
.expectedResults(ImmutableList.of(
new Object[]{"foo", 11L},
Expand Down

0 comments on commit 6784e9c

Please sign in to comment.