diff --git a/src/antlr/Lexer.g b/src/antlr/Lexer.g index 1e30ad18edd2..100541caf913 100644 --- a/src/antlr/Lexer.g +++ b/src/antlr/Lexer.g @@ -227,6 +227,7 @@ K_DROPPED: D R O P P E D; K_COLUMN: C O L U M N; K_RECORD: R E C O R D; K_ANN_OF: A N N WS+ O F; +K_BM25_OF: B M '2' '5' WS+ O F; // Case-insensitive alpha characters fragment A: ('a'|'A'); diff --git a/src/antlr/Parser.g b/src/antlr/Parser.g index 7fb85a9aeeea..43c4fe67c96f 100644 --- a/src/antlr/Parser.g +++ b/src/antlr/Parser.g @@ -457,14 +457,18 @@ customIndexExpression [WhereClause.Builder clause] ; orderByClause[List orderings] - @init{ + @init { Ordering.Direction direction = Ordering.Direction.ASC; + Ordering.Raw.Expression expr = null; } - : c=cident (K_ANN_OF t=term)? (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? + : c=cident + ( K_ANN_OF t=term { expr = new Ordering.Raw.Ann(c, t); } + | K_BM25_OF t=term { expr = new Ordering.Raw.Bm25(c, t); } + )? + (K_ASC | K_DESC { direction = Ordering.Direction.DESC; })? { - Ordering.Raw.Expression expr = (t == null) - ? 
new Ordering.Raw.SingleColumn(c) - : new Ordering.Raw.Ann(c, t); + if (expr == null) + expr = new Ordering.Raw.SingleColumn(c); orderings.add(new Ordering.Raw(expr, direction)); } ; @@ -1967,6 +1971,7 @@ basic_unreserved_keyword returns [String str] | K_COLUMN | K_RECORD | K_ANN_OF + | K_BM25_OF | K_OFFSET ) { $str = $k.text; } ; diff --git a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java index 3d5fb2eeeef7..640e7e600686 100644 --- a/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java +++ b/src/java/org/apache/cassandra/cql3/GeoDistanceRelation.java @@ -141,6 +141,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used with the GEO_DISTANCE function", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java index fcd505fcf5df..f56d76e2ced9 100644 --- a/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/MultiColumnRelation.java @@ -250,6 +250,12 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati throw invalidRequest("%s cannot be used for multi-column relations", operator()); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for multi-column relations", operator()); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git 
a/src/java/org/apache/cassandra/cql3/Operator.java b/src/java/org/apache/cassandra/cql3/Operator.java index 41b7985ffc4d..cd80020c1257 100644 --- a/src/java/org/apache/cassandra/cql3/Operator.java +++ b/src/java/org/apache/cassandra/cql3/Operator.java @@ -259,7 +259,7 @@ public String toString() @Override public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) { - return true; + throw new UnsupportedOperationException(); } }, NOT_IN(16) @@ -415,6 +415,7 @@ private boolean hasToken(AbstractType type, List tokens, ByteBuff return false; } }, + /** * An operator that performs a distance bounded approximate nearest neighbor search against a vector column such * that all result vectors are within a given distance of the query vector. The notable difference between this @@ -459,6 +460,20 @@ public String toString() return "DESC"; } + @Override + public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) + { + throw new UnsupportedOperationException(); + } + }, + BM25(104) + { + @Override + public String toString() + { + return "BM25"; + } + @Override public boolean isSatisfiedBy(AbstractType type, ByteBuffer leftOperand, ByteBuffer rightOperand, @Nullable Index.Analyzer analyzer) { diff --git a/src/java/org/apache/cassandra/cql3/Ordering.java b/src/java/org/apache/cassandra/cql3/Ordering.java index 81aa94a076cd..2dd817818a76 100644 --- a/src/java/org/apache/cassandra/cql3/Ordering.java +++ b/src/java/org/apache/cassandra/cql3/Ordering.java @@ -20,6 +20,7 @@ import org.apache.cassandra.cql3.restrictions.SingleColumnRestriction; import org.apache.cassandra.cql3.restrictions.SingleRestriction; +import org.apache.cassandra.cql3.statements.SelectStatement; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; @@ -48,6 +49,11 @@ public interface Expression SingleRestriction 
toRestriction(); ColumnMetadata getColumn(); + + default boolean isScored() + { + return false; + } } /** @@ -118,6 +124,54 @@ public ColumnMetadata getColumn() { return column; } + + @Override + public boolean isScored() + { + return SelectStatement.ANN_USE_SYNTHETIC_SCORE; + } + } + + /** + * An expression used in BM25 ordering. + * ORDER BY column BM25 OF value + */ + public static class Bm25 implements Expression + { + final ColumnMetadata column; + final Term queryValue; + final Direction direction; + + public Bm25(ColumnMetadata column, Term queryValue, Direction direction) + { + this.column = column; + this.queryValue = queryValue; + this.direction = direction; + } + + @Override + public boolean hasNonClusteredOrdering() + { + return true; + } + + @Override + public SingleRestriction toRestriction() + { + return new SingleColumnRestriction.Bm25Restriction(column, queryValue); + } + + @Override + public ColumnMetadata getColumn() + { + return column; + } + + @Override + public boolean isScored() + { + return true; + } } public enum Direction @@ -190,6 +244,27 @@ public Ordering.Expression bind(TableMetadata table, VariableSpecifications boun return new Ordering.Ann(column, value, direction); } } + + public static class Bm25 implements Expression + { + final ColumnIdentifier columnId; + final Term.Raw queryValue; + + Bm25(ColumnIdentifier column, Term.Raw queryValue) + { + this.columnId = column; + this.queryValue = queryValue; + } + + @Override + public Ordering.Expression bind(TableMetadata table, VariableSpecifications boundNames, Direction direction) + { + ColumnMetadata column = table.getExistingColumn(columnId); + Term value = queryValue.prepare(table.keyspace, column); + value.collectMarkerSpecification(boundNames); + return new Ordering.Bm25(column, value, direction); + } + } } } diff --git a/src/java/org/apache/cassandra/cql3/Relation.java b/src/java/org/apache/cassandra/cql3/Relation.java index 5cca2d257323..42cf3c9c8287 100644 --- 
a/src/java/org/apache/cassandra/cql3/Relation.java +++ b/src/java/org/apache/cassandra/cql3/Relation.java @@ -202,6 +202,8 @@ public final Restriction toRestriction(TableMetadata table, VariableSpecificatio return newLikeRestriction(table, boundNames, relationType); case ANN: return newAnnRestriction(table, boundNames); + case BM25: + return newBm25Restriction(table, boundNames); case ANALYZER_MATCHES: return newAnalyzerMatchesRestriction(table, boundNames); default: throw invalidRequest("Unsupported \"!=\" relation: %s", this); @@ -296,6 +298,11 @@ protected abstract Restriction newSliceRestriction(TableMetadata table, */ protected abstract Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames); + /** + * Creates a new BM25 restriction instance. + */ + protected abstract Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames); + /** * Creates a new Analyzer Matches restriction instance. */ diff --git a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java index ec66ad70b529..4dc6aaa78a7c 100644 --- a/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java +++ b/src/java/org/apache/cassandra/cql3/SingleColumnRelation.java @@ -21,8 +21,11 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; import org.apache.cassandra.db.marshal.VectorType; +import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.cql3.Term.Raw; @@ -33,6 +36,7 @@ import org.apache.cassandra.db.marshal.ListType; import org.apache.cassandra.db.marshal.MapType; import org.apache.cassandra.exceptions.InvalidRequestException; +import org.apache.cassandra.service.ClientWarn; import static 
org.apache.cassandra.cql3.statements.RequestValidations.checkFalse; import static org.apache.cassandra.cql3.statements.RequestValidations.checkTrue; @@ -191,7 +195,27 @@ protected Restriction newEQRestriction(TableMetadata table, VariableSpecificatio if (mapKey == null) { Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); - return new SingleColumnRestriction.EQRestriction(columnDef, term); + // Leave the restriction as EQ if no analyzed index in backwards compatibility mode is present + var ebi = IndexRegistry.obtain(table).getEqBehavior(columnDef); + if (ebi.behavior == IndexRegistry.EqBehavior.EQ) + return new SingleColumnRestriction.EQRestriction(columnDef, term); + + // the index is configured to transform EQ into MATCH for backwards compatibility + if (ebi.behavior == IndexRegistry.EqBehavior.MATCH) + { + ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, + columnDef.toString(), + ebi.matchIndex.getIndexMetadata().name), + columnDef); + return new SingleColumnRestriction.AnalyzerMatchesRestriction(columnDef, term); + } + + // multiple indexes support EQ, this is unsupported + assert ebi.behavior == IndexRegistry.EqBehavior.AMBIGUOUS; + throw invalidRequest(AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR, + columnDef.toString(), + ebi.matchIndex.getIndexMetadata().name, + ebi.eqIndex.getIndexMetadata().name); } List receivers = toReceivers(columnDef); Term entryKey = toTerm(Collections.singletonList(receivers.get(0)), mapKey, table.keyspace, boundNames); @@ -333,6 +357,14 @@ protected Restriction newAnnRestriction(TableMetadata table, VariableSpecificati return new SingleColumnRestriction.AnnRestriction(columnDef, term); } + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + ColumnMetadata columnDef = table.getExistingColumn(entity); + Term term = toTerm(toReceivers(columnDef), value, table.keyspace, boundNames); + return 
new SingleColumnRestriction.Bm25Restriction(columnDef, term); + } + @Override protected Restriction newAnalyzerMatchesRestriction(TableMetadata table, VariableSpecifications boundNames) { diff --git a/src/java/org/apache/cassandra/cql3/TokenRelation.java b/src/java/org/apache/cassandra/cql3/TokenRelation.java index a3ca586eee76..ca849dc82a30 100644 --- a/src/java/org/apache/cassandra/cql3/TokenRelation.java +++ b/src/java/org/apache/cassandra/cql3/TokenRelation.java @@ -138,7 +138,13 @@ protected Restriction newLikeRestriction(TableMetadata table, VariableSpecificat @Override protected Restriction newAnnRestriction(TableMetadata table, VariableSpecifications boundNames) { - throw invalidRequest("%s cannot be used for toekn relations", operator()); + throw invalidRequest("%s cannot be used for token relations", operator()); + } + + @Override + protected Restriction newBm25Restriction(TableMetadata table, VariableSpecifications boundNames) + { + throw invalidRequest("%s cannot be used for token relations", operator()); } @Override diff --git a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java index 092cc93cc7b7..0b33dfb16f69 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/ClusteringColumnRestrictions.java @@ -189,7 +189,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti SingleRestriction lastRestriction = restrictions.lastRestriction(); ColumnMetadata lastRestrictionStart = lastRestriction.getFirstColumn(); ColumnMetadata newRestrictionStart = newRestriction.getFirstColumn(); - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); checkFalse(lastRestriction.isSlice() && newRestrictionStart.position() > lastRestrictionStart.position(), 
"Clustering column \"%s\" cannot be restricted (preceding column \"%s\" is restricted by a non-EQ relation)", @@ -203,7 +203,7 @@ public ClusteringColumnRestrictions.Builder addRestriction(Restriction restricti } else { - restrictions.addRestriction(newRestriction, isDisjunction, indexRegistry); + restrictions.addRestriction(newRestriction, isDisjunction); } return this; diff --git a/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java index b08925fafe6f..30df843651db 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/MultiColumnRestriction.java @@ -127,10 +127,7 @@ protected final String getColumnsInCommons(Restriction otherRestriction) @Override public final boolean hasSupportingIndex(IndexRegistry indexRegistry) { - for (Index index : indexRegistry.listIndexes()) - if (isSupportingIndex(index)) - return true; - return false; + return findSupportingIndex(indexRegistry) != null; } @Override diff --git a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java index 50e6ea481f09..faddd9d9ff06 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/PartitionKeySingleRestrictionSet.java @@ -189,7 +189,7 @@ public PartitionKeyRestrictions build(IndexRegistry indexRegistry, boolean isDis if (restriction.isOnToken()) return buildWithTokens(restrictionSet, i, indexRegistry); - restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction, indexRegistry); + restrictionSet.addRestriction((SingleRestriction) restriction, isDisjunction); } return buildPartitionKeyRestrictions(restrictionSet); diff --git 
a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java index 92736c0d298a..9d459e1da247 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/RestrictionSet.java @@ -413,45 +413,36 @@ private Builder() { } - public void addRestriction(SingleRestriction restriction, boolean isDisjunction, IndexRegistry indexRegistry) + public void addRestriction(SingleRestriction restriction, boolean isDisjunction) { List columnDefs = restriction.getColumnDefs(); if (isDisjunction) { // If this restriction is part of a disjunction query then we don't want - // to merge the restrictions (if that is possible), we just add the - // restriction to the set of restrictions for the column. + // to merge the restrictions, we just add the new restriction addRestrictionForColumns(columnDefs, restriction, null); } else { - // In some special cases such as EQ in analyzed index we need to skip merging the restriction, - // so we can send multiple EQ restrictions to the index. - if (restriction.skipMerge(indexRegistry)) - { - addRestrictionForColumns(columnDefs, restriction, null); - return; - } - - // If this restriction isn't part of a disjunction then we need to get - // the set of existing restrictions for the column and merge them with the - // new restriction + // ANDed together restrictions against the same columns should be merged. Set existingRestrictions = getRestrictions(newRestrictions, columnDefs); - SingleRestriction merged = restriction; - Set replacedRestrictions = new HashSet<>(); - - for (SingleRestriction existing : existingRestrictions) + // merge the new restriction into an existing one. note that there is only ever a single + // restriction (per column), UNLESS one is ORDER BY BM25 and the other is MATCH. 
+ for (var existing : existingRestrictions) { - if (!existing.skipMerge(indexRegistry)) + // shouldMerge exists for the BM25/MATCH case + if (existing.shouldMerge(restriction)) { - merged = existing.mergeWith(merged); - replacedRestrictions.add(existing); + var merged = existing.mergeWith(restriction); + addRestrictionForColumns(merged.getColumnDefs(), merged, Set.of(existing)); + return; } } - addRestrictionForColumns(merged.getColumnDefs(), merged, replacedRestrictions); + // no existing restrictions that we should merge the new one with, add a new one + addRestrictionForColumns(columnDefs, restriction, null); } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java index 264353a34928..137ee3ab1a54 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleColumnRestriction.java @@ -24,14 +24,18 @@ import java.util.List; import java.util.Map; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.cql3.*; +import org.apache.cassandra.cql3.MarkerOrTerms; +import org.apache.cassandra.cql3.Operator; +import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.Term; +import org.apache.cassandra.cql3.Terms; import org.apache.cassandra.cql3.functions.Function; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; +import org.apache.cassandra.db.filter.RowFilter; import org.apache.cassandra.index.Index; import org.apache.cassandra.index.IndexRegistry; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.serializers.ListSerializer; import org.apache.cassandra.transport.ProtocolVersion; import org.apache.cassandra.utils.ByteBufferUtil; @@ -198,24 +202,6 @@ public String toString() return 
String.format("EQ(%s)", term); } - @Override - public boolean skipMerge(IndexRegistry indexRegistry) - { - // We should skip merging this EQ if there is an analyzed index for this column that supports EQ, - // so there can be multiple EQs for the same column. - - if (indexRegistry == null) - return false; - - for (Index index : indexRegistry.listIndexes()) - { - if (index.supportsExpression(columnDef, Operator.ANALYZER_MATCHES) && - index.supportsExpression(columnDef, Operator.EQ)) - return true; - } - return false; - } - @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { @@ -1191,6 +1177,88 @@ public boolean isBoundedAnn() } } + public static final class Bm25Restriction extends SingleColumnRestriction + { + private final Term value; + + public Bm25Restriction(ColumnMetadata columnDef, Term value) + { + super(columnDef); + this.value = value; + } + + public ByteBuffer value(QueryOptions options) + { + return value.bindAndGet(options); + } + + @Override + public void addFunctionsTo(List functions) + { + value.addFunctionsTo(functions); + } + + @Override + MultiColumnRestriction toMultiColumnRestriction() + { + throw new UnsupportedOperationException(); + } + + @Override + public void addToRowFilter(RowFilter.Builder filter, + IndexRegistry indexRegistry, + QueryOptions options) + { + var index = findSupportingIndex(indexRegistry); + var valueBytes = value.bindAndGet(options); + var terms = index.getAnalyzer().get().analyze(valueBytes); + if (terms.isEmpty()) + throw invalidRequest("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)"); + filter.add(columnDef, Operator.BM25, valueBytes); + } + + @Override + public MultiClusteringBuilder appendTo(MultiClusteringBuilder builder, QueryOptions options) + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() + { + return String.format("BM25(%s)", value); + } + + @Override + public SingleRestriction 
doMergeWith(SingleRestriction otherRestriction) + { + throw invalidRequest("%s cannot be restricted by both BM25 and %s", columnDef.name, otherRestriction.toString()); + } + + @Override + protected boolean isSupportedBy(Index index) + { + return index.supportsExpression(columnDef, Operator.BM25); + } + + @Override + public boolean isIndexBasedOrdering() + { + return true; + } + + @Override + public boolean shouldMerge(SingleRestriction other) + { + // we don't want to merge MATCH restrictions with ORDER BY BM25 + // so shouldMerge = false for that scenario, and true for others + // (because even though we can't meaningfully merge with others, we want doMergeWith to be called to throw) + // + // (Note that because ORDER BY is processed before WHERE, we only need this check in the BM25 class) + return !other.isAnalyzerMatches(); + } + } + /** * A Bounded ANN Restriction is one that uses a similarity score as the limiting factor for ANN instead of a number * of results. @@ -1336,10 +1404,12 @@ public String toString() @Override public SingleRestriction doMergeWith(SingleRestriction otherRestriction) { - if (!(otherRestriction.isAnalyzerMatches())) + if (!otherRestriction.isAnalyzerMatches()) throw invalidRequest(CANNOT_BE_MERGED_ERROR, columnDef.name); - List otherValues = ((AnalyzerMatchesRestriction) otherRestriction).getValues(); + List otherValues = otherRestriction instanceof AnalyzerMatchesRestriction + ? 
((AnalyzerMatchesRestriction) otherRestriction).getValues() + : List.of(((EQRestriction) otherRestriction).term); List newValues = new ArrayList<>(values.size() + otherValues.size()); newValues.addAll(values); newValues.addAll(otherValues); diff --git a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java index 595451f812de..bdd80badc0ae 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/SingleRestriction.java @@ -20,7 +20,6 @@ import org.apache.cassandra.cql3.QueryOptions; import org.apache.cassandra.cql3.statements.Bound; import org.apache.cassandra.db.MultiClusteringBuilder; -import org.apache.cassandra.index.IndexRegistry; /** * A single restriction/clause on one or multiple column. @@ -97,17 +96,6 @@ public default boolean isInclusive(Bound b) return true; } - /** - * Checks if this restriction shouldn't be merged with other restrictions. - * - * @param indexRegistry the index registry - * @return {@code true} if this shouldn't be merged with other restrictions - */ - default boolean skipMerge(IndexRegistry indexRegistry) - { - return false; - } - /** * Merges this restriction with the specified one. * @@ -141,4 +129,16 @@ public default MultiClusteringBuilder appendBoundTo(MultiClusteringBuilder build { return appendTo(builder, options); } + + /** + * @return true if the other restriction should be merged with this one. + * This is NOT for preventing illegal combinations of restrictions, e.g. + * a=1 AND a=2; that is handled by mergeWith. Instead, this is for the case + * where we want two completely different semantics against the same column. + * Currently the only such case is BM25 with MATCH. 
+ */ + default boolean shouldMerge(SingleRestriction other) + { + return true; + } } diff --git a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java index 95bdff2391e8..6509b741ef3d 100644 --- a/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java +++ b/src/java/org/apache/cassandra/cql3/restrictions/StatementRestrictions.java @@ -81,6 +81,7 @@ public class StatementRestrictions "Restriction on partition key column %s must not be nested under OR operator"; public static final String GEO_DISTANCE_REQUIRES_INDEX_MESSAGE = "GEO_DISTANCE requires the vector column to be indexed"; + public static final String BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE = "BM25 ordering on column %s requires an analyzed index"; public static final String NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE = "Ordering on non-clustering column %s requires the column to be indexed."; public static final String NON_CLUSTER_ORDERING_REQUIRES_ALL_RESTRICTED_NON_PARTITION_KEY_COLUMNS_INDEXED_MESSAGE = "Ordering on non-clustering column requires each restricted column to be indexed except for fully-specified partition keys"; @@ -442,7 +443,7 @@ else if (def.isClusteringColumn() && nestingLevel == 0) } else { - nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction(), indexRegistry); + nonPrimaryKeyRestrictionSet.addRestriction((SingleRestriction) restriction, element.isDisjunction()); } } } @@ -682,7 +683,8 @@ else if (indexOrderings.size() == 1) if (orderings.size() > 1) throw new InvalidRequestException("Cannot combine clustering column ordering with non-clustering column ordering"); Ordering ordering = indexOrderings.get(0); - if (ordering.direction != Ordering.Direction.ASC && ordering.expression instanceof Ordering.Ann) + // TODO remove the instanceof with SelectStatement.ANN_USE_SYNTHETIC_SCORE. 
+ if (ordering.direction != Ordering.Direction.ASC && (ordering.expression.isScored() || ordering.expression instanceof Ordering.Ann)) throw new InvalidRequestException("Descending ANN ordering is not supported"); if (!ENABLE_SAI_GENERAL_ORDER_BY && ordering.expression instanceof Ordering.SingleColumn) throw new InvalidRequestException("SAI based ORDER BY on non-vector column is not supported"); @@ -695,10 +697,14 @@ else if (indexOrderings.size() == 1) throw new InvalidRequestException(String.format("SAI based ordering on column %s of type %s is not supported", restriction.getFirstColumn(), restriction.getFirstColumn().type.asCQL3Type())); - throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, - restriction.getFirstColumn())); + if (ordering.expression instanceof Ordering.Bm25) + throw new InvalidRequestException(String.format(BM25_ORDERING_REQUIRES_ANALYZED_INDEX_MESSAGE, + restriction.getFirstColumn())); + else + throw new InvalidRequestException(String.format(NON_CLUSTER_ORDERING_REQUIRES_INDEX_MESSAGE, + restriction.getFirstColumn())); } - receiver.addRestriction(restriction, false, indexRegistry); + receiver.addRestriction(restriction, false); } } diff --git a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java index 63fa0520101e..00225cca4108 100644 --- a/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java +++ b/src/java/org/apache/cassandra/cql3/selection/ColumnFilterFactory.java @@ -38,9 +38,21 @@ abstract class ColumnFilterFactory */ abstract ColumnFilter newInstance(List selectors); - public static ColumnFilterFactory wildcard(TableMetadata table) + public static ColumnFilterFactory wildcard(TableMetadata table, Set orderingColumns) { - return new PrecomputedColumnFilter(ColumnFilter.all(table)); + ColumnFilter cf; + if (orderingColumns.isEmpty()) + { + cf = ColumnFilter.all(table); + } + else + { + 
ColumnFilter.Builder builder = ColumnFilter.selectionBuilder(); + builder.addAll(table.regularAndStaticColumns()); + builder.addAll(orderingColumns); + cf = builder.build(); + } + return new PrecomputedColumnFilter(cf); } public static ColumnFilterFactory fromColumns(TableMetadata table, diff --git a/src/java/org/apache/cassandra/cql3/selection/Selection.java b/src/java/org/apache/cassandra/cql3/selection/Selection.java index 02aae61dd5ff..12d8aa014e19 100644 --- a/src/java/org/apache/cassandra/cql3/selection/Selection.java +++ b/src/java/org/apache/cassandra/cql3/selection/Selection.java @@ -43,10 +43,24 @@ public abstract class Selection private static final Predicate STATIC_COLUMN_FILTER = (column) -> column.isStatic(); private final TableMetadata table; + + // Full list of columns needed for processing the query, including selected columns, ordering columns, + // and columns needed for restrictions. Wildcard columns are fully materialized here. + // + // This also includes synthetic columns, because unlike all the other not-physical-columns selectables, they are + // computed on the replica instead of the coordinator and so, like physical columns, they need to be sent back + // as part of the result. private final List columns; + + // maps ColumnSpecifications (columns, function calls, aliases) to the columns backing them private final SelectionColumnMapping columnMapping; + + // metadata matching the ColumnSpecifications protected final ResultSet.ResultMetadata metadata; + + // creates a ColumnFilter that breaks columns into `queried` and `fetched` protected final ColumnFilterFactory columnFilterFactory; + protected final boolean isJson; // Columns used to order the result set for JSON queries with post ordering. 
@@ -126,10 +140,15 @@ public ResultSet.ResultMetadata getResultMetadata() } public static Selection wildcard(TableMetadata table, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) + { + return wildcard(table, Collections.emptySet(), isJson, returnStaticContentOnPartitionWithNoRows); + } + + public static Selection wildcard(TableMetadata table, Set orderingColumns, boolean isJson, boolean returnStaticContentOnPartitionWithNoRows) { List all = new ArrayList<>(table.columns().size()); Iterators.addAll(all, table.allColumnsInSelectOrder()); - return new SimpleSelection(table, all, Collections.emptySet(), true, isJson, returnStaticContentOnPartitionWithNoRows); + return new SimpleSelection(table, all, orderingColumns, true, isJson, returnStaticContentOnPartitionWithNoRows); } public static Selection wildcardWithGroupBy(TableMetadata table, @@ -400,7 +419,7 @@ public SimpleSelection(TableMetadata table, selectedColumns, orderingColumns, SelectionColumnMapping.simpleMapping(selectedColumns), - isWildcard ? ColumnFilterFactory.wildcard(table) + isWildcard ? 
ColumnFilterFactory.wildcard(table, orderingColumns) : ColumnFilterFactory.fromColumns(table, selectedColumns, orderingColumns, Collections.emptySet(), returnStaticContentOnPartitionWithNoRows), isWildcard, isJson); diff --git a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java index b15931e17831..d2025783ebf4 100644 --- a/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java +++ b/src/java/org/apache/cassandra/cql3/statements/SelectStatement.java @@ -25,7 +25,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import com.google.common.base.Preconditions; import com.google.common.math.IntMath; import org.apache.cassandra.cql3.Ordering; @@ -39,8 +38,8 @@ import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.cql3.restrictions.ExternalRestriction; import org.apache.cassandra.cql3.restrictions.Restrictions; -import org.apache.cassandra.cql3.restrictions.SingleRestriction; import org.apache.cassandra.cql3.selection.SortedRowsBuilder; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.guardrails.Guardrails; import org.apache.cassandra.index.Index; import org.apache.cassandra.schema.ColumnMetadata; @@ -102,6 +101,12 @@ */ public class SelectStatement implements CQLStatement.SingleKeyspaceCqlStatement { + // TODO remove this when we no longer need to downgrade to replicas that don't know about synthetic columns, + // and the related code in + // - StatementRestrictions.addOrderingRestrictions + // - StorageAttachedIndexSearcher.PrimaryKeyIterator constructor + public static final boolean ANN_USE_SYNTHETIC_SCORE = Boolean.parseBoolean(System.getProperty("cassandra.sai.ann_use_synthetic_score", "false")); + private static final Logger logger = LoggerFactory.getLogger(SelectStatement.class); private static final NoSpamLogger noSpamLogger = 
NoSpamLogger.getLogger(SelectStatement.logger, 1, TimeUnit.MINUTES); public static final String TOPK_CONSISTENCY_LEVEL_ERROR = "Top-K queries can only be run with consistency level ONE/LOCAL_ONE. Consistency level %s was used."; @@ -1082,12 +1087,16 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil case CLUSTERING: result.add(row.clustering().bufferAt(def.position())); break; + case SYNTHETIC: + // treat as REGULAR case REGULAR: result.add(row.getColumnData(def), nowInSec); break; case STATIC: result.add(staticRow.getColumnData(def), nowInSec); break; + default: + throw new AssertionError(); } } } @@ -1095,19 +1104,16 @@ void processPartition(RowIterator partition, QueryOptions options, ResultSetBuil private boolean needsToSkipUserLimit() { - // if post query ordering is required, and it's not ordered by an index - return needsPostQueryOrdering() && !needIndexOrdering(); + // if we're querying by `pk IN (...)` and ordering by clustered columns, replicas don't sort + // before applying LIMIT + return needsPostQueryOrdering() && (orderingComparator != null && orderingComparator.isClustered()); } private boolean needsPostQueryOrdering() { // We need post-query ordering only for queries with IN on the partition key and an ORDER BY or index restriction reordering - return restrictions.keyIsInRelation() && !parameters.orderings.isEmpty() || needIndexOrdering(); - } - - private boolean needIndexOrdering() - { - return orderingComparator != null && orderingComparator.indexOrdering(); + return (restrictions.keyIsInRelation() && !parameters.orderings.isEmpty()) + || orderingComparator != null; } /** @@ -1115,35 +1121,23 @@ private boolean needIndexOrdering() */ public SortedRowsBuilder sortedRowsBuilder(int limit, int offset, QueryOptions options) { - assert (orderingComparator != null) == needsPostQueryOrdering() - : String.format("orderingComparator: %s, needsPostQueryOrdering: %s", - orderingComparator, needsPostQueryOrdering()); - if 
(orderingComparator == null) - { return SortedRowsBuilder.create(limit, offset); - } - else if (orderingComparator instanceof IndexColumnComparator) + + if (orderingComparator instanceof VectorColumnComparator) { - SingleRestriction restriction = ((IndexColumnComparator) orderingComparator).restriction; - int columnIndex = ((IndexColumnComparator) orderingComparator).columnIndex; + SingleRestriction restriction = ((VectorColumnComparator) orderingComparator).restriction; + int columnIndex = ((VectorColumnComparator) orderingComparator).columnIndex; Index index = restriction.findSupportingIndex(IndexRegistry.obtain(table)); assert index != null; - if (restriction instanceof SingleColumnRestriction.OrderRestriction) - { - var comparator = index.postQueryComparator(restriction, columnIndex, options); - return SortedRowsBuilder.create(limit, offset, comparator); - } - Index.Scorer scorer = index.postQueryScorer(restriction, columnIndex, options); return SortedRowsBuilder.create(limit, offset, scorer); } - else - { - return SortedRowsBuilder.create(limit, offset, orderingComparator); - } + + // else + return SortedRowsBuilder.create(limit, offset, orderingComparator); } public static class RawStatement extends QualifiedStatement @@ -1187,6 +1181,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa List selectables = RawSelector.toSelectables(selectClause, table); boolean containsOnlyStaticColumns = selectOnlyStaticColumns(table, selectables); + // Besides actual restrictions (where clauses), prepareRestrictions will include pseudo-restrictions + // on indexed columns to allow pushing ORDER BY into the index; see StatementRestrictions::addOrderingRestrictions. + // Therefore, we don't want to convert an ANN Ordering column into a +score column until after that. 
List orderings = getOrderings(table); StatementRestrictions restrictions = prepareRestrictions( table, bindVariables, orderings, containsOnlyStaticColumns, forView); @@ -1194,6 +1191,11 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa // If we order post-query, the sorted column needs to be in the ResultSet for sorting, // even if we don't ultimately ship them to the client (CASSANDRA-4911). Map orderingColumns = getOrderingColumns(orderings); + // +score column for ANN/BM25 + var scoreOrdering = getScoreOrdering(orderings); + assert scoreOrdering == null || orderingColumns.isEmpty() : "can't have both scored ordering and column ordering"; + if (scoreOrdering != null) + orderingColumns = scoreOrdering; Set resultSetOrderingColumns = getResultSetOrdering(restrictions, orderingColumns); Selection selection = prepareSelection(table, @@ -1222,9 +1224,9 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa if (!orderingColumns.isEmpty()) { assert !forView; - verifyOrderingIsAllowed(restrictions, orderingColumns); + verifyOrderingIsAllowed(table, restrictions, orderingColumns); orderingComparator = getOrderingComparator(selection, restrictions, orderingColumns); - isReversed = isReversed(table, orderingColumns, restrictions); + isReversed = isReversed(table, orderingColumns); if (isReversed && orderingComparator != null) orderingComparator = orderingComparator.reverse(); } @@ -1247,6 +1249,21 @@ public SelectStatement prepare(boolean forView, UnaryOperator keyspaceMa prepareLimit(bindVariables, offset, ks, offsetReceiver())); } + private Map getScoreOrdering(List orderings) + { + if (orderings.isEmpty()) + return null; + + var expr = orderings.get(0).expression; + if (!expr.isScored()) + return null; + + // Create synthetic score column + ColumnMetadata sourceColumn = expr.getColumn(); + var cm = ColumnMetadata.syntheticColumn(sourceColumn.ksName, sourceColumn.cfName, ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); 
+ return Map.of(cm, orderings.get(0)); + } + private Set getResultSetOrdering(StatementRestrictions restrictions, Map orderingColumns) { if (restrictions.keyIsInRelation() || orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) @@ -1265,7 +1282,7 @@ private Selection prepareSelection(TableMetadata table, if (selectables.isEmpty()) // wildcard query { return hasGroupBy ? Selection.wildcardWithGroupBy(table, boundNames, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()) - : Selection.wildcard(table, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); + : Selection.wildcard(table, resultSetOrderingColumns, parameters.isJson, restrictions.returnStaticContentOnPartitionWithNoRows()); } return Selection.fromSelectors(table, @@ -1307,13 +1324,14 @@ private Map getOrderingColumns(List ordering if (orderings.isEmpty()) return Collections.emptyMap(); - Map orderingColumns = new LinkedHashMap<>(); - for (Ordering ordering : orderings) - { - ColumnMetadata column = ordering.expression.getColumn(); - orderingColumns.put(column, ordering); - } - return orderingColumns; + return orderings.stream() + .filter(ordering -> !ordering.expression.isScored()) + .collect(Collectors.toMap(ordering -> ordering.expression.getColumn(), + ordering -> ordering, + (a, b) -> { + throw new IllegalStateException("Duplicate keys"); + }, + LinkedHashMap::new)); } private List getOrderings(TableMetadata table) @@ -1360,12 +1378,28 @@ private Term prepareLimit(VariableSpecifications boundNames, Term.Raw limit, return prepLimit; } - private static void verifyOrderingIsAllowed(StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException + private static void verifyOrderingIsAllowed(TableMetadata table, StatementRestrictions restrictions, Map orderingColumns) throws InvalidRequestException { if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) return; + 
checkFalse(restrictions.usesSecondaryIndexing(), "ORDER BY with 2ndary indexes is not supported."); checkFalse(restrictions.isKeyRange(), "ORDER BY is only supported when the partition key is restricted by an EQ or an IN."); + + // check that clustering columns are valid + int i = 0; + for (var entry : orderingColumns.entrySet()) + { + ColumnMetadata def = entry.getKey(); + checkTrue(def.isClusteringColumn(), + "Order by is currently only supported on indexed columns and the clustered columns of the PRIMARY KEY, got %s", def.name); + while (i != def.position()) + { + checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), + "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"); + } + i++; + } } private static void validateDistinctSelection(TableMetadata metadata, @@ -1455,20 +1489,23 @@ private ColumnComparator> getOrderingComparator(Selection selec Map orderingColumns) throws InvalidRequestException { - for (Map.Entry e : orderingColumns.entrySet()) + if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) { - if (e.getValue().expression.hasNonClusteredOrdering()) - { - Preconditions.checkState(orderingColumns.size() == 1); - return new IndexColumnComparator<>(e.getValue().expression.toRestriction(), selection.getOrderingIndex(e.getKey())); - } + assert orderingColumns.size() == 1 : orderingColumns.keySet(); + var e = orderingColumns.entrySet().iterator().next(); + var column = e.getKey(); + var ordering = e.getValue(); + if (ordering.expression instanceof Ordering.Ann && !ANN_USE_SYNTHETIC_SCORE) + return new VectorColumnComparator(ordering.expression.toRestriction(), selection.getOrderingIndex(column)); + else + return new SingleColumnComparator(selection.getOrderingIndex(column), column.type, false); } if (!restrictions.keyIsInRelation()) return null; - List idToSort = new ArrayList(orderingColumns.size()); - List> sorters = new ArrayList>(orderingColumns.size()); + 
List idToSort = new ArrayList<>(orderingColumns.size()); + List> sorters = new ArrayList<>(orderingColumns.size()); for (ColumnMetadata orderingColumn : orderingColumns.keySet()) { @@ -1480,35 +1517,26 @@ private ColumnComparator> getOrderingComparator(Selection selec : new CompositeComparator(sorters, idToSort); } - private boolean isReversed(TableMetadata table, Map orderingColumns, StatementRestrictions restrictions) throws InvalidRequestException + private boolean isReversed(TableMetadata table, Map orderingColumns) throws InvalidRequestException { - // Nonclustered ordering handles descending logic in a different way - if (orderingColumns.values().stream().anyMatch(o -> o.expression.hasNonClusteredOrdering())) - return false; - - Boolean[] reversedMap = new Boolean[table.clusteringColumns().size()]; - int i = 0; + Boolean[] clusteredMap = new Boolean[table.clusteringColumns().size()]; for (var entry : orderingColumns.entrySet()) { ColumnMetadata def = entry.getKey(); Ordering ordering = entry.getValue(); - boolean reversed = ordering.direction == Ordering.Direction.DESC; - - // VSTODO move this to verifyOrderingIsAllowed? - checkTrue(def.isClusteringColumn(), - "Order by is currently only supported on the clustered columns of the PRIMARY KEY, got %s", def.name); - while (i != def.position()) - { - checkTrue(restrictions.isColumnRestrictedByEq(table.clusteringColumns().get(i++)), - "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"); - } - i++; - reversedMap[def.position()] = (reversed != def.isReversedType()); + // We defined ANN OF to be ASC ordering, as in, "order by near-ness". But since score goes from + // 0 (worst) to 1 (closest), we need to reverse the ordering for the comparator when we're sorting + // by synthetic +score column. 
+ boolean cqlReversed = ordering.direction == Ordering.Direction.DESC; + if (def.position() == ColumnMetadata.NO_POSITION) + return ordering.expression.isScored() || cqlReversed; + else + clusteredMap[def.position()] = (cqlReversed != def.isReversedType()); } - // Check that all boolean in reversedMap, if set, agrees + // Check that all boolean in clusteredMap, if set, agrees Boolean isReversed = null; - for (Boolean b : reversedMap) + for (Boolean b : clusteredMap) { // Column on which order is specified can be in any order if (b == null) @@ -1631,11 +1659,11 @@ public ColumnComparator reverse() } /** - * @return true if ordering is performed by index + * @return true if ordering is performed by classic collation columns */ - public boolean indexOrdering() + public boolean isClustered() { - return false; + return true; } } @@ -1653,6 +1681,12 @@ public int compare(T o1, T o2) { return wrapped.compare(o2, o1); } + + @Override + public boolean isClustered() + { + return wrapped.isClustered(); + } } /** * Used in orderResults(...) 
method when single 'ORDER BY' condition where given @@ -1661,35 +1695,48 @@ private static class SingleColumnComparator extends ColumnComparator comparator; + private final boolean clustered; - public SingleColumnComparator(int columnIndex, Comparator orderer) + public SingleColumnComparator(int columnIndex, Comparator orderer, boolean clustered) { index = columnIndex; comparator = orderer; + this.clustered = clustered; + } + + public SingleColumnComparator(int columnIndex, Comparator orderer) + { + this(columnIndex, orderer, true); } public int compare(List a, List b) { return compare(comparator, a.get(index), b.get(index)); } + + @Override + public boolean isClustered() + { + return clustered; + } } - private static class IndexColumnComparator extends ColumnComparator> + // placeholder for postQueryScorer call; see usage in sortedRowsBuilder + private static class VectorColumnComparator extends ColumnComparator> { private final SingleRestriction restriction; private final int columnIndex; - // VSTODO maybe cache in prepared statement - public IndexColumnComparator(SingleRestriction restriction, int columnIndex) + public VectorColumnComparator(SingleRestriction restriction, int columnIndex) { this.restriction = restriction; this.columnIndex = columnIndex; } @Override - public boolean indexOrdering() + public boolean isClustered() { - return true; + return false; } @Override diff --git a/src/java/org/apache/cassandra/db/Columns.java b/src/java/org/apache/cassandra/db/Columns.java index 7ce9bd68cf15..f85a9b803141 100644 --- a/src/java/org/apache/cassandra/db/Columns.java +++ b/src/java/org/apache/cassandra/db/Columns.java @@ -28,6 +28,7 @@ import net.nicoulaj.compilecommand.annotations.DontInline; import org.apache.cassandra.cql3.ColumnIdentifier; +import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.marshal.SetType; import org.apache.cassandra.db.marshal.UTF8Type; import org.apache.cassandra.db.rows.ColumnData; @@ -36,6 +37,7 @@ 
import org.apache.cassandra.io.util.DataOutputPlus; import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; +import org.apache.cassandra.serializers.AbstractTypeSerializer; import org.apache.cassandra.utils.ByteBufferUtil; import org.apache.cassandra.utils.ObjectSizes; import org.apache.cassandra.utils.SearchIterator; @@ -459,37 +461,107 @@ public String toString() public static class Serializer { + AbstractTypeSerializer typeSerializer = new AbstractTypeSerializer(); + public void serialize(Columns columns, DataOutputPlus out) throws IOException { - out.writeUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + + // Count regular and synthetic columns + for (ColumnMetadata column : columns) + { + if (column.isSynthetic()) + syntheticCount++; + else + regularCount++; + } + + // Jam the two counts into a single value to avoid massive backwards compatibility issues + long packedCount = getPackedCount(syntheticCount, regularCount); + out.writeUnsignedVInt(packedCount); + + // First pass - write synthetic columns with their full metadata for (ColumnMetadata column : columns) - ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + { + if (column.isSynthetic()) + { + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + typeSerializer.serialize(column.type, out); + } + } + + // Second pass - write regular columns + for (ColumnMetadata column : columns) + { + if (!column.isSynthetic()) + ByteBufferUtil.writeWithVIntLength(column.name.bytes, out); + } + } + + private static long getPackedCount(int syntheticCount, int regularCount) + { + // Left shift of 20 gives us over 1M regular columns, and up to 4 synthetic columns + // before overflowing to a 4th byte. 
+ return ((long) syntheticCount << 20) | regularCount; } public long serializedSize(Columns columns) { - long size = TypeSizes.sizeofUnsignedVInt(columns.size()); + int regularCount = 0; + int syntheticCount = 0; + long size = 0; + + // Count and calculate sizes for (ColumnMetadata column : columns) - size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); - return size; + { + if (column.isSynthetic()) + { + syntheticCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + size += typeSerializer.serializedSize(column.type); + } + else + { + regularCount++; + size += ByteBufferUtil.serializedSizeWithVIntLength(column.name.bytes); + } + } + + return TypeSizes.sizeofUnsignedVInt(getPackedCount(syntheticCount, regularCount)) + + size; } public Columns deserialize(DataInputPlus in, TableMetadata metadata) throws IOException { - int length = (int)in.readUnsignedVInt(); try (BTree.FastBuilder builder = BTree.fastBuilder()) { - for (int i = 0; i < length; i++) + long packedCount = in.readUnsignedVInt() ; + int regularCount = (int) (packedCount & 0xFFFFF); + int syntheticCount = (int) (packedCount >> 20); + + // First pass - synthetic columns + for (int i = 0; i < syntheticCount; i++) + { + ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); + AbstractType type = typeSerializer.deserialize(in); + + if (!name.equals(ColumnMetadata.SYNTHETIC_SCORE_ID.bytes)) + throw new IllegalStateException("Unknown synthetic column " + UTF8Type.instance.getString(name)); + + ColumnMetadata column = ColumnMetadata.syntheticColumn(metadata.keyspace, metadata.name, ColumnMetadata.SYNTHETIC_SCORE_ID, type); + builder.add(column); + } + + // Second pass - regular columns + for (int i = 0; i < regularCount; i++) { ByteBuffer name = ByteBufferUtil.readWithVIntLength(in); ColumnMetadata column = metadata.getColumn(name); if (column == null) { - // If we don't find the definition, it could be we have data for a dropped column, and we shouldn't - // 
fail deserialization because of that. So we grab a "fake" ColumnMetadata that ensure proper - // deserialization. The column will be ignore later on anyway. + // If we don't find the definition, it could be we have data for a dropped column column = metadata.getDroppedColumn(name); - if (column == null) throw new RuntimeException("Unknown column " + UTF8Type.instance.getString(name) + " during deserialization"); } diff --git a/src/java/org/apache/cassandra/db/ReadCommand.java b/src/java/org/apache/cassandra/db/ReadCommand.java index d872147c0d0f..a8e88b14094d 100644 --- a/src/java/org/apache/cassandra/db/ReadCommand.java +++ b/src/java/org/apache/cassandra/db/ReadCommand.java @@ -72,6 +72,7 @@ import org.apache.cassandra.net.Message; import org.apache.cassandra.net.MessageFlag; import org.apache.cassandra.net.Verb; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.IndexMetadata; import org.apache.cassandra.schema.Schema; import org.apache.cassandra.schema.SchemaConstants; @@ -382,8 +383,6 @@ static Index.QueryPlan findIndexQueryPlan(TableMetadata table, RowFilter rowFilt @Override public void maybeValidateIndexes() { - IndexRegistry.obtain(metadata()).validate(rowFilter()); - if (null != indexQueryPlan) indexQueryPlan.validate(this); } @@ -413,9 +412,9 @@ public UnfilteredPartitionIterator executeLocally(ReadExecutionController execut } Context context = Context.from(this); - UnfilteredPartitionIterator iterator = (null == searcher) ? Transformation.apply(queryStorage(cfs, executionController), new TrackingRowIterator(context)) - : Transformation.apply(searchStorage(searcher, executionController), new TrackingRowIterator(context)); - + var storageTarget = (null == searcher) ? 
queryStorage(cfs, executionController) + : searchStorage(searcher, executionController); + UnfilteredPartitionIterator iterator = Transformation.apply(storageTarget, new TrackingRowIterator(context)); iterator = RTBoundValidator.validate(iterator, Stage.MERGED, false); try @@ -1047,6 +1046,19 @@ public ReadCommand deserialize(DataInputPlus in, int version) throws IOException TableMetadata metadata = schema.getExistingTableMetadata(TableId.deserialize(in)); int nowInSec = in.readInt(); ColumnFilter columnFilter = ColumnFilter.serializer.deserialize(in, version, metadata); + + // add synthetic columns to the tablemetadata so we can serialize them in our response + var tmb = metadata.unbuild(); + for (var it = columnFilter.fetchedColumns().regulars.simpleColumns(); it.hasNext(); ) + { + var c = it.next(); + // synthetic columns sort first, so when we hit the first non-synthetic, we're done + if (!c.isSynthetic()) + break; + tmb.addColumn(ColumnMetadata.syntheticColumn(c.ksName, c.cfName, c.name, c.type)); + } + metadata = tmb.build(); + RowFilter rowFilter = RowFilter.serializer.deserialize(in, version, metadata); DataLimits limits = DataLimits.serializer.deserialize(in, version, metadata.comparator); Index.QueryPlan indexQueryPlan = null; diff --git a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java index b6da183d013f..55533eda0e97 100644 --- a/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java +++ b/src/java/org/apache/cassandra/db/RegularAndStaticColumns.java @@ -163,7 +163,7 @@ public Builder add(ColumnMetadata c) } else { - assert c.isRegular(); + assert c.isRegular() || c.isSynthetic(); if (regularColumns == null) regularColumns = BTree.builder(naturalOrder()); regularColumns.add(c); @@ -197,7 +197,7 @@ public Builder addAll(RegularAndStaticColumns columns) public RegularAndStaticColumns build() { - return new RegularAndStaticColumns(staticColumns == null ? 
Columns.NONE : Columns.from(staticColumns), + return new RegularAndStaticColumns(staticColumns == null ? Columns.NONE : Columns.from(staticColumns), regularColumns == null ? Columns.NONE : Columns.from(regularColumns)); } } diff --git a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java index d9a1b9d4e51a..644e6d661a61 100644 --- a/src/java/org/apache/cassandra/db/filter/ColumnFilter.java +++ b/src/java/org/apache/cassandra/db/filter/ColumnFilter.java @@ -75,6 +75,9 @@ public abstract class ColumnFilter public static final Serializer serializer = new Serializer(); + // TODO remove this with ANN_USE_SYNTHETIC_SCORE + public abstract boolean fetchesExplicitly(ColumnMetadata column); + /** * The fetching strategy for the different queries. */ @@ -103,7 +106,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return metadata.regularAndStaticColumns(); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(metadata.staticColumns(), merged); } }, @@ -124,7 +128,8 @@ boolean fetchesAllColumns(boolean isStatic) @Override RegularAndStaticColumns getFetchedColumns(TableMetadata metadata, RegularAndStaticColumns queried) { - return new RegularAndStaticColumns(queried.statics, metadata.regularColumns()); + var merged = queried.regulars.mergeTo(metadata.regularColumns()); + return new RegularAndStaticColumns(queried.statics, merged); } }, @@ -295,14 +300,16 @@ public static ColumnFilter selection(TableMetadata metadata, } /** - * The columns that needs to be fetched internally for this filter. + * The columns that need to be fetched internally. See FetchingStrategy for why this is + * always a superset of the queried columns. * * @return the columns to fetch for this filter. 
*/ public abstract RegularAndStaticColumns fetchedColumns(); /** - * The columns actually queried by the user. + * The columns needed to process the query, including selected columns, ordering columns, + * restriction (predicate) columns, and synthetic columns. *

* Note that this is in general not all the columns that are fetched internally (see {@link #fetchedColumns}). */ @@ -619,9 +626,7 @@ private SortedSetMultimap buildSubSelectio */ public static class WildCardColumnFilter extends ColumnFilter { - /** - * The queried and fetched columns. - */ + // for wildcards, there is no distinction between fetched and queried because queried is already "everything" private final RegularAndStaticColumns fetchedAndQueried; /** @@ -667,6 +672,12 @@ public boolean fetches(ColumnMetadata column) return true; } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return false; + } + @Override public boolean fetchedColumnIsQueried(ColumnMetadata column) { @@ -739,14 +750,9 @@ public static class SelectionColumnFilter extends ColumnFilter { public final FetchingStrategy fetchingStrategy; - /** - * The selected columns - */ + // Materializes the columns required to implement queriedColumns() and fetchedColumns(), + // see the comments to superclass's methods private final RegularAndStaticColumns queried; - - /** - * The columns that need to be fetched to be able - */ private final RegularAndStaticColumns fetched; private final SortedSetMultimap subSelections; // can be null @@ -820,6 +826,12 @@ public boolean fetches(ColumnMetadata column) return fetchingStrategy.fetchesAllColumns(column.isStatic()) || fetched.contains(column); } + @Override + public boolean fetchesExplicitly(ColumnMetadata column) + { + return fetched.contains(column); + } + /** * Whether the provided complex cell (identified by its column and path), which is assumed to be _fetched_ by * this filter, is also _queried_ by the user. 
diff --git a/src/java/org/apache/cassandra/db/filter/RowFilter.java b/src/java/org/apache/cassandra/db/filter/RowFilter.java index d090cf37eb5f..6b316c91bfa0 100644 --- a/src/java/org/apache/cassandra/db/filter/RowFilter.java +++ b/src/java/org/apache/cassandra/db/filter/RowFilter.java @@ -1129,6 +1129,7 @@ public boolean isSatisfiedBy(TableMetadata metadata, DecoratedKey partitionKey, case LIKE_MATCHES: case ANALYZER_MATCHES: case ANN: + case BM25: { assert !column.isComplex() : "Only CONTAINS and CONTAINS_KEY are supported for 'complex' types"; ByteBuffer foundValue = getValue(metadata, partitionKey, row); diff --git a/src/java/org/apache/cassandra/index/Index.java b/src/java/org/apache/cassandra/index/Index.java index 7839df669495..46024d2e6adf 100644 --- a/src/java/org/apache/cassandra/index/Index.java +++ b/src/java/org/apache/cassandra/index/Index.java @@ -23,7 +23,6 @@ import java.nio.ByteBuffer; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -471,20 +470,6 @@ interface Analyzer */ public RowFilter getPostIndexQueryFilter(RowFilter filter); - /** - * Returns a {@link Comparator} of CQL result rows, so they can be ordered by the - * coordinator before sending them to client. - * - * @param restriction restriction that requires current index - * @param columnIndex idx of the indexed column in returned row - * @param options query options - * @return a comparator of rows - */ - default Comparator> postQueryComparator(Restriction restriction, int columnIndex, QueryOptions options) - { - throw new NotImplementedException(); - } - /** * Returns a {@link Scorer} to give a similarity/proximity score to CQL result rows, so they can be ordered by the * coordinator before sending them to client. 
diff --git a/src/java/org/apache/cassandra/index/IndexRegistry.java b/src/java/org/apache/cassandra/index/IndexRegistry.java index cc6b0d103ea9..2d7ac81a0e8c 100644 --- a/src/java/org/apache/cassandra/index/IndexRegistry.java +++ b/src/java/org/apache/cassandra/index/IndexRegistry.java @@ -21,6 +21,7 @@ package org.apache.cassandra.index; import java.util.Collection; +import java.util.HashSet; import java.util.Collections; import java.util.Optional; import java.util.Set; @@ -102,12 +103,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; /** @@ -295,12 +290,6 @@ public Optional getBestIndexFor(RowFilter.Expression expression) public void validate(PartitionUpdate update) { } - - @Override - public void validate(RowFilter filter) - { - // no-op since it's an empty registry - } }; default void registerIndex(Index index) @@ -341,8 +330,6 @@ default Optional getAnalyzerFor(ColumnMetadata column, Operator */ void validate(PartitionUpdate update); - void validate(RowFilter filter); - /** * Returns the {@code IndexRegistry} associated to the specified table. * @@ -356,4 +343,74 @@ public static IndexRegistry obtain(TableMetadata table) return table.isVirtual() ? 
EMPTY : Keyspace.openAndGetStore(table).indexManager; } + + enum EqBehavior + { + EQ, + MATCH, + AMBIGUOUS + } + + class EqBehaviorIndexes + { + public EqBehavior behavior; + public final Index eqIndex; + public final Index matchIndex; + + private EqBehaviorIndexes(Index eqIndex, Index matchIndex, EqBehavior behavior) + { + this.eqIndex = eqIndex; + this.matchIndex = matchIndex; + this.behavior = behavior; + } + + public static EqBehaviorIndexes eq(Index eqIndex) + { + return new EqBehaviorIndexes(eqIndex, null, EqBehavior.EQ); + } + + public static EqBehaviorIndexes match(Index eqAndMatchIndex) + { + return new EqBehaviorIndexes(eqAndMatchIndex, eqAndMatchIndex, EqBehavior.MATCH); + } + + public static EqBehaviorIndexes ambiguous(Index firstEqIndex, Index secondEqIndex) + { + return new EqBehaviorIndexes(firstEqIndex, secondEqIndex, EqBehavior.AMBIGUOUS); + } + } + + /** + * @return + * - AMBIGUOUS if an index supports EQ and a different one supports both EQ and ANALYZER_MATCHES + * - MATCH if an index supports both EQ and ANALYZER_MATCHES + * - otherwise EQ + */ + default EqBehaviorIndexes getEqBehavior(ColumnMetadata cm) + { + Index eqOnlyIndex = null; + Index bothIndex = null; + + for (Index index : listIndexes()) + { + boolean supportsEq = index.supportsExpression(cm, Operator.EQ); + boolean supportsMatches = index.supportsExpression(cm, Operator.ANALYZER_MATCHES); + + if (supportsEq && supportsMatches) + bothIndex = index; + else if (supportsEq) + eqOnlyIndex = index; + } + + // If we have one index supporting only EQ and another supporting both, return AMBIGUOUS + if (eqOnlyIndex != null && bothIndex != null) + return EqBehaviorIndexes.ambiguous(eqOnlyIndex, bothIndex); + + // If we have an index supporting both EQ and ANALYZER_MATCHES, return MATCH + if (bothIndex != null) + return EqBehaviorIndexes.match(bothIndex); + + // Otherwise return EQ + return EqBehaviorIndexes.eq(eqOnlyIndex == null ? 
bothIndex : eqOnlyIndex); + } } diff --git a/src/java/org/apache/cassandra/index/RowFilterValidator.java b/src/java/org/apache/cassandra/index/RowFilterValidator.java deleted file mode 100644 index fb70fbfc1452..000000000000 --- a/src/java/org/apache/cassandra/index/RowFilterValidator.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright DataStax, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.cassandra.index; - -import java.util.HashSet; -import java.util.Set; -import java.util.StringJoiner; - -import org.apache.cassandra.cql3.Operator; -import org.apache.cassandra.db.filter.RowFilter; -import org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport; -import org.apache.cassandra.schema.ColumnMetadata; -import org.apache.cassandra.service.ClientWarn; - -/** - * Class for validating the index-related aspects of a {@link RowFilter}, without considering what index is actually used. - *

- * It will emit a client warning when a query has EQ restrictions on columns having an analyzed index. - */ -class RowFilterValidator -{ - private final Iterable allIndexes; - - private Set columns; - private Set indexes; - - private RowFilterValidator(Iterable allIndexes) - { - this.allIndexes = allIndexes; - } - - private void addEqRestriction(ColumnMetadata column) - { - for (Index index : allIndexes) - { - if (index.supportsExpression(column, Operator.EQ) && - index.supportsExpression(column, Operator.ANALYZER_MATCHES)) - { - if (columns == null) - columns = new HashSet<>(); - columns.add(column); - - if (indexes == null) - indexes = new HashSet<>(); - indexes.add(index); - } - } - } - - private void validate() - { - if (columns == null || indexes == null) - return; - - StringJoiner columnNames = new StringJoiner(", "); - StringJoiner indexNames = new StringJoiner(", "); - columns.forEach(column -> columnNames.add(column.name.toString())); - indexes.forEach(index -> indexNames.add(index.getIndexMetadata().name)); - - ClientWarn.instance.warn(String.format(AnalyzerEqOperatorSupport.EQ_RESTRICTION_ON_ANALYZED_WARNING, columnNames, indexNames)); - } - - /** - * Emits a client warning if the filter contains EQ restrictions on columns having an analyzed index. 
- * - * @param filter the filter to validate - * @param indexes the existing indexes - */ - public static void validate(RowFilter filter, Iterable indexes) - { - RowFilterValidator validator = new RowFilterValidator(indexes); - validate(filter.root(), validator); - validator.validate(); - } - - private static void validate(RowFilter.FilterElement element, RowFilterValidator validator) - { - for (RowFilter.Expression expression : element.expressions()) - { - if (expression.operator() == Operator.EQ) - validator.addEqRestriction(expression.column()); - } - - for (RowFilter.FilterElement child : element.children()) - { - validate(child, validator); - } - } -} diff --git a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java index b1930e741804..d4c75ed2a4f1 100644 --- a/src/java/org/apache/cassandra/index/SecondaryIndexManager.java +++ b/src/java/org/apache/cassandra/index/SecondaryIndexManager.java @@ -1277,12 +1277,6 @@ public void validate(PartitionUpdate update) throws InvalidRequestException index.validate(update); } - @Override - public void validate(RowFilter filter) - { - RowFilterValidator.validate(filter, indexes.values()); - } - /* * IndexRegistry methods */ diff --git a/src/java/org/apache/cassandra/index/sai/IndexContext.java b/src/java/org/apache/cassandra/index/sai/IndexContext.java index 94b8cd1afed1..890ecd2cc7aa 100644 --- a/src/java/org/apache/cassandra/index/sai/IndexContext.java +++ b/src/java/org/apache/cassandra/index/sai/IndexContext.java @@ -689,8 +689,8 @@ public boolean supports(Operator op) { if (op.isLike() || op == Operator.LIKE) return false; // Analyzed columns store the indexed result, so we are unable to compute raw equality. - // The only supported operator is ANALYZER_MATCHES. - if (op == Operator.ANALYZER_MATCHES) return isAnalyzed; + // The only supported operators are ANALYZER_MATCHES and BM25. 
+ if (op == Operator.ANALYZER_MATCHES || op == Operator.BM25) return isAnalyzed; // If the column is analyzed and the operator is EQ, we need to check if the analyzer supports it. if (op == Operator.EQ && isAnalyzed && !analyzerFactory.supportsEquals()) @@ -714,7 +714,6 @@ public boolean supports(Operator op) || column.type instanceof IntegerType); // Currently truncates to 20 bytes Expression.Op operator = Expression.Op.valueOf(op); - if (isNonFrozenCollection()) { if (indexType == IndexTarget.Type.KEYS) @@ -726,17 +725,12 @@ public boolean supports(Operator op) return indexType == IndexTarget.Type.KEYS_AND_VALUES && (operator == Expression.Op.EQ || operator == Expression.Op.NOT_EQ || operator == Expression.Op.RANGE); } - if (indexType == IndexTarget.Type.FULL) return operator == Expression.Op.EQ; - AbstractType validator = getValidator(); - if (operator == Expression.Op.IN) return true; - if (operator != Expression.Op.EQ && EQ_ONLY_TYPES.contains(validator)) return false; - // RANGE only applicable to non-literal indexes return (operator != null) && !(TypeUtil.isLiteral(validator) && operator == Expression.Op.RANGE); } diff --git a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java index fd2ab5795dce..6867bac87dca 100644 --- a/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java +++ b/src/java/org/apache/cassandra/index/sai/StorageAttachedIndex.java @@ -305,20 +305,26 @@ public static Map validateOptions(Map options, T throw new InvalidRequestException("Failed to retrieve target column for: " + targetColumn); } - // In order to support different index target on non-frozen map, ie. KEYS, VALUE, ENTRIES, we need to put index - // name as part of index file name instead of column name. We only need to check that the target is different - // between indexes. This will only allow indexes in the same column with a different IndexTarget.Type. 
- // - // Note that: "metadata.indexes" already includes current index - if (metadata.indexes.stream().filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) - .map(index -> TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME))) - .filter(Objects::nonNull).filter(t -> t.equals(target)).count() > 1) - { - throw new InvalidRequestException("Cannot create more than one storage-attached index on the same column: " + target.left); - } + // Check for duplicate indexes considering both target and analyzer configuration + boolean isAnalyzed = AbstractAnalyzer.isAnalyzed(options); + long duplicateCount = metadata.indexes.stream() + .filter(index -> index.getIndexClassName().equals(StorageAttachedIndex.class.getName())) + .filter(index -> { + // Indexes on the same column with different target (KEYS, VALUES, ENTRIES) + // are allowed on non-frozen Maps + var existingTarget = TargetParser.parse(metadata, index.options.get(IndexTarget.TARGET_OPTION_NAME)); + if (existingTarget == null || !existingTarget.equals(target)) + return false; + // Also allow different indexes if one is analyzed and the other isn't + return isAnalyzed == AbstractAnalyzer.isAnalyzed(index.options); + }) + .count(); + // >1 because "metadata.indexes" already includes current index + if (duplicateCount > 1) + throw new InvalidRequestException(String.format("Cannot create duplicate storage-attached index on column: %s", target.left)); // Analyzer is not supported against PK columns - if (AbstractAnalyzer.isAnalyzed(options)) + if (isAnalyzed) { for (ColumnMetadata column : metadata.primaryKeyColumns()) { @@ -695,22 +701,10 @@ public RowFilter getPostIndexQueryFilter(RowFilter filter) throw new UnsupportedOperationException(); } - @Override - public Comparator> postQueryComparator(Restriction restriction, int columnIndex, QueryOptions options) - { - assert restriction instanceof SingleColumnRestriction.OrderRestriction; - - 
SingleColumnRestriction.OrderRestriction orderRestriction = (SingleColumnRestriction.OrderRestriction) restriction; - var typeComparator = orderRestriction.getDirection() == Operator.ORDER_BY_DESC - ? indexContext.getValidator().reversed() - : indexContext.getValidator(); - return (a, b) -> typeComparator.compare(a.get(columnIndex), b.get(columnIndex)); - } - @Override public Scorer postQueryScorer(Restriction restriction, int columnIndex, QueryOptions options) { - // For now, only support ANN + // TODO remove this with SelectStatement.ANN_USE_SYNTHETIC_SCORE. assert restriction instanceof SingleColumnRestriction.AnnRestriction; Preconditions.checkState(indexContext.isVector()); diff --git a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java index 116bc7f62832..30408c9b986f 100644 --- a/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java +++ b/src/java/org/apache/cassandra/index/sai/analyzer/AnalyzerEqOperatorSupport.java @@ -49,7 +49,7 @@ public class AnalyzerEqOperatorSupport OPTION, Arrays.toString(Value.values())); public static final String EQ_RESTRICTION_ON_ANALYZED_WARNING = - String.format("Columns [%%s] are restricted by '=' and have analyzed indexes [%%s] able to process those restrictions. " + + String.format("Column [%%s] is restricted by '=' and has an analyzed index [%%s] able to process those restrictions. " + "Analyzed indexes might process '=' restrictions in a way that is inconsistent with non-indexed queries. " + "While '=' is still supported on analyzed indexes for backwards compatibility, " + "it is recommended to use the ':' operator instead to prevent the ambiguity. 
" + @@ -58,6 +58,13 @@ public class AnalyzerEqOperatorSupport "please use '%s':'%s' in the index options.", OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + public static final String EQ_AMBIGUOUS_ERROR = + String.format("Column [%%s] equality predicate is ambiguous. It has both an analyzed index [%%s] configured with '%s':'%s', " + + "and an un-analyzed index [%%s]. " + + "To avoid ambiguity, drop the analyzed index and recreate it with option '%s':'%s'.", + OPTION, Value.MATCH.toString().toLowerCase(), OPTION, Value.UNSUPPORTED.toString().toLowerCase()); + + public static final String LWT_CONDITION_ON_ANALYZED_WARNING = "Index analyzers not applied to LWT conditions on columns [%s]."; diff --git a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java index 8b6f07d43597..f2035bd9631a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java +++ b/src/java/org/apache/cassandra/index/sai/disk/RAMStringIndexer.java @@ -140,27 +140,36 @@ private ByteComparable asByteComparable(byte[] bytes, int offset, int length) }; } + /** + * @return bytes allocated. may be zero if the (term, row) pair is a duplicate + */ public long add(BytesRef term, int segmentRowId) { long startBytes = estimatedBytesUsed(); int termID = termsHash.add(term); + boolean firstOccurrence = termID >= 0; - if (termID >= 0) + if (firstOccurrence) { - // firs time seeing this term, create the term's first slice ! + // first time seeing this term, create the term's first slice ! 
slices.createNewSlice(termID); } else { termID = (-termID) - 1; + // compaction should call this method only with non-decreasing segmentRowIds + assert segmentRowId >= lastSegmentRowID[termID]; + // Skip if we've already recorded this segmentRowId for this term + if (segmentRowId == lastSegmentRowID[termID]) + return 0; } if (termID >= lastSegmentRowID.length - 1) - { lastSegmentRowID = ArrayUtil.grow(lastSegmentRowID, termID + 1); - } int delta = segmentRowId - lastSegmentRowID[termID]; + // sanity check that we're advancing the row id, i.e. no duplicate entries. + assert firstOccurrence || delta > 0; lastSegmentRowID[termID] = segmentRowId; diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java index 52988a685cd5..16d14bdd8ac3 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/IndexSearcher.java @@ -68,9 +68,7 @@ public abstract class IndexSearcher implements Closeable, SegmentOrdering protected final SegmentMetadata metadata; protected final IndexContext indexContext; - private static final SSTableReadsListener NOOP_LISTENER = new SSTableReadsListener() {}; - - private final ColumnFilter columnFilter; + protected final ColumnFilter columnFilter; protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, PerIndexFiles perIndexFiles, @@ -90,30 +88,36 @@ protected IndexSearcher(PrimaryKeyMap.Factory primaryKeyMapFactory, public abstract long indexFileCacheSize(); /** - * Search on-disk index synchronously + * Search on-disk index synchronously. Used for WHERE clause predicates, including BOUNDED_ANN. 
 * * @param expression to filter on disk index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics * @param defer create the iterator in a deferred state - * @param limit the num of rows to returned, used by ANN index * @return {@link KeyRangeIterator} that matches given expression */ - public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer, int limit) throws IOException; + public abstract KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext queryContext, boolean defer) throws IOException; /** - * Order the on-disk index synchronously and produce an iterator in score order + * Order the rows by the given Orderer. Used for ORDER BY clause when + * (1) the WHERE predicate is either a partition restriction or a range restriction on the index, + * (2) there is no WHERE predicate, or + * (3) the planner determines it is better to post-filter the ordered results by the predicate. * * @param orderer the object containing the ordering logic * @param slice optional predicate to get a slice of the index * @param keyRange key range specific in read command, used by ANN index * @param queryContext to track per sstable cache and per query metrics - * @param limit the num of rows to returned, used by ANN index + * @param limit the initial number of rows to return, used by ANN index. More rows may be requested if filtering throws away more than expected! * @return an iterator of {@link PrimaryKeyWithSortKey} in score order */ public abstract CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException; - + /** + * Order the rows by the given Orderer. Used for ORDER BY clause when the WHERE predicates + * have been applied first, yielding a list of primary keys. 
Again, `limit` is a planner hint for ANN to determine + * the initial number of results returned, not a maximum. + */ @Override public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext context, List keys, Orderer orderer, int limit) throws IOException { @@ -124,7 +128,7 @@ public CloseableIterator orderResultsBy(SSTableReader rea { var slices = Slices.with(indexContext.comparator(), Slice.make(key.clustering())); // TODO if we end up needing to read the row still, is it better to store offset and use reader.unfilteredAt? - try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, NOOP_LISTENER)) + try (var iter = reader.iterator(key.partitionKey(), slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) { if (iter.hasNext()) { diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java index 15aaa349e1c8..fb1bb3ac455f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java @@ -19,14 +19,26 @@ package org.apache.cassandra.index.sai.disk.v1; import java.io.IOException; +import java.io.UncheckedIOException; import java.lang.invoke.MethodHandles; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; import com.google.common.base.MoreObjects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.Slice; +import org.apache.cassandra.db.Slices; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; import org.apache.cassandra.index.sai.IndexContext; import 
org.apache.cassandra.index.sai.QueryContext; @@ -35,24 +47,30 @@ import org.apache.cassandra.index.sai.disk.TermsIterator; import org.apache.cassandra.index.sai.disk.format.IndexComponentType; import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.disk.v1.postings.IntersectingPostingList; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.MulticastQueryEventListeners; import org.apache.cassandra.index.sai.metrics.QueryEventListener; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; +import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable; import org.apache.cassandra.index.sai.utils.SAICodecUtils; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; +import org.apache.cassandra.io.sstable.format.SSTableReader; +import org.apache.cassandra.io.sstable.format.SSTableReadsListener; import org.apache.cassandra.io.util.FileUtils; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.bytecomparable.ByteComparable; +import static org.apache.cassandra.index.sai.disk.PostingList.END_OF_STREAM; + /** * Executes {@link Expression}s against the trie-based terms dictionary for an individual index segment. 
*/ -public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrdering +public class InvertedIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); @@ -60,6 +78,7 @@ public class InvertedIndexSearcher extends IndexSearcher implements SegmentOrder private final QueryEventListener.TrieIndexEventListener perColumnEventListener; private final Version version; private final boolean filterRangeResults; + private final SSTableReader sstable; protected InvertedIndexSearcher(SSTableContext sstableContext, PerIndexFiles perIndexFiles, @@ -69,6 +88,7 @@ protected InvertedIndexSearcher(SSTableContext sstableContext, boolean filterRangeResults) throws IOException { super(sstableContext.primaryKeyMapFactory(), perIndexFiles, segmentMetadata, indexContext); + this.sstable = sstableContext.sstable; long root = metadata.getIndexRoot(IndexComponentType.TERMS_DATA); assert root >= 0; @@ -100,7 +120,7 @@ public long indexFileCacheSize() } @SuppressWarnings("resource") - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); @@ -129,11 +149,113 @@ else if (exp.getOp() == Expression.Op.RANGE) throw new IllegalArgumentException(indexContext.logMessage("Unsupported expression: " + exp)); } + private Cell readColumn(SSTableReader sstable, PrimaryKey primaryKey) + { + var dk = primaryKey.partitionKey(); + var slices = Slices.with(indexContext.comparator(), Slice.make(primaryKey.clustering())); + try (var rowIterator = sstable.iterator(dk, slices, columnFilter, false, SSTableReadsListener.NOOP_LISTENER)) + { + var unfiltered = rowIterator.next(); + assert unfiltered.isRow() : 
unfiltered; + Row row = (Row) unfiltered; + return row.getCell(indexContext.getDefinition()); + } + } + @Override public CloseableIterator orderBy(Orderer orderer, Expression slice, AbstractBounds keyRange, QueryContext queryContext, int limit) throws IOException { - var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); - return toMetaSortedIterator(iter, queryContext); + if (!orderer.isBM25()) + { + var iter = new RowIdWithTermsIterator(reader.allTerms(orderer.isAscending())); + return toMetaSortedIterator(iter, queryContext); + } + + // find documents that match each term + var queryTerms = orderer.getQueryTerms(); + var postingLists = queryTerms.stream() + .collect(Collectors.toMap(Function.identity(), term -> + { + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postings = reader.exactMatch(encodedTerm, listener, queryContext); + return postings == null ? 
PostingList.EMPTY : postings; + })); + // extract the match count for each + var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size())); + + try (var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap(); + var merged = IntersectingPostingList.intersect(List.copyOf(postingLists.values()))) + { + // construct an Iterator() from our intersected postings + var it = new AbstractIterator() { + @Override + protected PrimaryKey computeNext() + { + try + { + int rowId = merged.nextPosting(); + if (rowId == PostingList.END_OF_STREAM) + return endOfData(); + return pkm.primaryKeyFromRowId(rowId); + } + catch (IOException e) + { + throw new UncheckedIOException(e); + } + } + }; + return bm25Internal(it, queryTerms, documentFrequencies); + } + } + + private CloseableIterator bm25Internal(Iterator keyIterator, + List queryTerms, + Map documentFrequencies) + { + var totalRows = sstable.getTotalRows(); + // since doc frequencies can be an estimate from the index histogram, which does not have bounded error, + // cap frequencies to total rows so that the IDF term doesn't turn negative + var cappedFrequencies = documentFrequencies.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> Math.min(e.getValue(), totalRows))); + var docStats = new BM25Utils.DocStats(cappedFrequencies, totalRows); + return BM25Utils.computeScores(keyIterator, + queryTerms, + docStats, + indexContext, + sstable.descriptor.id, + pk -> readColumn(sstable, pk)); + } + + @Override + public CloseableIterator orderResultsBy(SSTableReader reader, QueryContext queryContext, List keys, Orderer orderer, int limit) throws IOException + { + if (!orderer.isBM25()) + return super.orderResultsBy(reader, queryContext, keys, orderer, limit); + + var queryTerms = orderer.getQueryTerms(); + // compute documentFrequencies from either histogram or an index search + var documentFrequencies = new HashMap(); + boolean hasHistograms = 
metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram(); + for (ByteBuffer term : queryTerms) + { + long matches; + if (hasHistograms) + { + matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term)); + } + else + { + // Without histograms, need to do an actual index scan. exactMatch can return null (orderBy maps that case to PostingList.EMPTY), so count a null posting list as zero matches + var encodedTerm = version.onDiskFormat().encodeForTrie(term, indexContext.getValidator()); + var listener = MulticastQueryEventListeners.of(queryContext, perColumnEventListener); + var postingList = this.reader.exactMatch(encodedTerm, listener, queryContext); + matches = postingList == null ? 0 : postingList.size(); + FileUtils.closeQuietly(postingList); + } + documentFrequencies.put(term, matches); + } + return bm25Internal(keys.iterator(), queryTerms, documentFrequencies); } @Override @@ -172,7 +294,7 @@ protected RowIdWithByteComparable computeNext() while (true) { long nextPosting = currentPostingList.nextPosting(); - if (nextPosting != PostingList.END_OF_STREAM) + if (nextPosting != END_OF_STREAM) return new RowIdWithByteComparable(Math.toIntExact(nextPosting), currentTerm); if (!source.hasNext()) diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java index 8a2aa354bd47..5911f2351014 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcher.java @@ -84,7 +84,7 @@ public long indexFileCacheSize() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { PostingList postingList = searchPosting(exp, context); return toPrimaryKeyIterator(postingList, context); diff --git 
a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java index 87fe0a998e5f..3d52c944d726 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java @@ -144,7 +144,7 @@ public long indexFileCacheSize() */ public KeyRangeIterator search(Expression expression, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException { - return index.search(expression, keyRange, context, defer, limit); + return index.search(expression, keyRange, context, defer); } /** diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java index 504df79c924a..77c89c1d466a 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/SegmentBuilder.java @@ -208,6 +208,7 @@ protected long addInternal(ByteBuffer term, int segmentRowId) var encodedTerm = components.onDiskFormat().encodeForTrie(term, termComparator); var bytes = ByteSourceInverse.readBytes(encodedTerm.asComparableBytes(byteComparableVersion)); var bytesRef = new BytesRef(bytes); + // ramIndexer is responsible for merging duplicate (term, row) pairs return ramIndexer.add(bytesRef, segmentRowId); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java new file mode 100644 index 000000000000..8c3bfe3ee6df --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingList.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.index.sai.disk.v1.postings;
+
+import java.io.IOException;
+import java.util.List;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import org.apache.cassandra.index.sai.disk.PostingList;
+import org.apache.cassandra.io.util.FileUtils;
+
+/**
+ * Performs intersection operations on multiple PostingLists, returning only postings
+ * that appear in all inputs.
+ * <p>
+ * Instances are obtained through {@link #intersect(List)}. Closing the returned list closes
+ * every input list that was handed to {@link #intersect(List)}.
+ */
+@NotThreadSafe
+public class IntersectingPostingList implements PostingList
+{
+    // The sublists being intersected; never empty.
+    private final List<PostingList> postingLists;
+    // The smallest size() reported by any sublist, computed once at construction.
+    private final int size;
+
+    /**
+     * @param postingLists the lists to intersect; must contain at least one element
+     * @throws AssertionError if {@code postingLists} is empty
+     */
+    private IntersectingPostingList(List<PostingList> postingLists)
+    {
+        if (postingLists.isEmpty())
+            throw new AssertionError("Cannot intersect an empty collection of posting lists");
+        this.postingLists = postingLists;
+        this.size = postingLists.stream()
+                                .mapToInt(PostingList::size)
+                                .min()
+                                .orElse(0);
+    }
+
+    /**
+     * Builds the intersection of the provided posting lists.
+     * <ul>
+     *   <li>A single input is returned as-is, without wrapping.</li>
+     *   <li>If any input reports {@code isEmpty()}, an empty list is returned that still
+     *       closes every input when it is closed.</li>
+     *   <li>Otherwise an {@link IntersectingPostingList} over all inputs is returned.</li>
+     * </ul>
+     *
+     * @param postingLists the lists to intersect; must not be empty
+     * @return the intersection of the provided posting lists
+     * @throws AssertionError if {@code postingLists} is empty
+     */
+    public static PostingList intersect(List<PostingList> postingLists)
+    {
+        if (postingLists.size() == 1)
+            return postingLists.get(0);
+
+        if (postingLists.stream().anyMatch(PostingList::isEmpty))
+            return new EmptyIntersectingList(postingLists);
+
+        return new IntersectingPostingList(postingLists);
+    }
+
+    /**
+     * @return the next row ID present in every sublist, or {@code END_OF_STREAM} once any
+     *         sublist is exhausted
+     */
+    @Override
+    public int nextPosting() throws IOException
+    {
+        // Integer.MIN_VALUE is only a placeholder: with isAdvance == false the first sublist
+        // is moved with nextPosting() and its result replaces the placeholder.
+        return findNextIntersection(Integer.MIN_VALUE, false);
+    }
+
+    /**
+     * @param targetRowID the row ID to advance to; must be non-negative
+     * @return the row ID found by advancing every sublist to {@code targetRowID} and then
+     *         converging on a common value, or {@code END_OF_STREAM} once any sublist is exhausted
+     */
+    @Override
+    public int advance(int targetRowID) throws IOException
+    {
+        assert targetRowID >= 0 : targetRowID;
+        return findNextIntersection(targetRowID, true);
+    }
+
+    /**
+     * Moves the sublists forward until all of them are positioned on the same row ID.
+     *
+     * @param targetRowID the initial candidate row ID
+     * @param isAdvance   when {@code true} the first sublist is moved with {@code advance(targetRowID)};
+     *                    when {@code false} it is moved with {@code nextPosting()}
+     * @return the common row ID, or {@code END_OF_STREAM} as soon as any sublist returns it
+     */
+    private int findNextIntersection(int targetRowID, boolean isAdvance) throws IOException
+    {
+        int maxRowId = targetRowID;
+        int maxRowIdIndex = -1;
+
+        // Scan through all posting lists looking for a common row ID
+        for (int i = 0; i < postingLists.size(); i++)
+        {
+            // don't advance the sublist in which we found our current max
+            if (i == maxRowIdIndex)
+                continue;
+
+            // Advance this sublist to the current max, special casing the first one as needed
+            PostingList list = postingLists.get(i);
+            int rowId = (isAdvance || maxRowIdIndex >= 0)
+                        ? list.advance(maxRowId)
+                        : list.nextPosting();
+            if (rowId == END_OF_STREAM)
+                return END_OF_STREAM;
+
+            // Update maxRowId + index if we find a larger value, or this was the first sublist evaluated
+            if (rowId > maxRowId || maxRowIdIndex < 0)
+            {
+                maxRowId = rowId;
+                maxRowIdIndex = i;
+                i = -1; // restart the scan with new maxRowId
+            }
+        }
+
+        // Once we complete a full scan without finding a larger rowId, we've found an intersection
+        return maxRowId;
+    }
+
+    /**
+     * @return the smallest {@code size()} reported by the sublists at construction time
+     */
+    @Override
+    public int size()
+    {
+        return size;
+    }
+
+    /**
+     * Closes every sublist via {@code FileUtils.closeQuietly}.
+     */
+    @Override
+    public void close()
+    {
+        for (PostingList list : postingLists)
+            FileUtils.closeQuietly(list);
+    }
+
+    /**
+     * Empty result returned by {@link #intersect(List)} when at least one input is empty.
+     * It keeps a reference to the inputs only so that they are closed together with it.
+     */
+    private static class EmptyIntersectingList extends EmptyPostingList
+    {
+        private final List<PostingList> lists;
+
+        public EmptyIntersectingList(List<PostingList> postingLists)
+        {
+            this.lists = postingLists;
+        }
+
+        @Override
+        public void close()
+        {
+            for (PostingList list : lists)
+                FileUtils.closeQuietly(list);
+        }
+    }
+}
+
+
diff --git a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java
similarity index 83%
rename from src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java
rename to src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java
index 52155bf6ed59..f43c7c4e8dce 100644
---
a/src/java/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingList.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingList.java @@ -19,31 +19,29 @@ package org.apache.cassandra.index.sai.disk.v1.postings; import java.io.IOException; +import java.util.function.ToIntFunction; import org.apache.cassandra.index.sai.disk.PostingList; -import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.utils.CloseableIterator; import org.apache.lucene.util.LongHeap; /** * A posting list for ANN search results. Transforms results from similarity order to rowId order. */ -public class VectorPostingList implements PostingList +public class ReorderingPostingList implements PostingList { private final LongHeap segmentRowIds; private final int size; - public VectorPostingList(CloseableIterator source) + public ReorderingPostingList(CloseableIterator source, ToIntFunction rowIdTransformer) { - // TODO find int specific data structure? segmentRowIds = new LongHeap(32); int n = 0; - // Once the source is consumed, we have to close it. 
try (source) { while (source.hasNext()) { - segmentRowIds.push(source.next().getSegmentRowId()); + segmentRowIds.push(rowIdTransformer.applyAsInt(source.next())); n++; } } diff --git a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java index 62748e3332c4..21771ac94e1f 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java +++ b/src/java/org/apache/cassandra/index/sai/disk/v2/V2VectorIndexSearcher.java @@ -49,7 +49,7 @@ import org.apache.cassandra.index.sai.disk.v1.IndexSearcher; import org.apache.cassandra.index.sai.disk.v1.PerIndexFiles; import org.apache.cassandra.index.sai.disk.v1.SegmentMetadata; -import org.apache.cassandra.index.sai.disk.v1.postings.VectorPostingList; +import org.apache.cassandra.index.sai.disk.v1.postings.ReorderingPostingList; import org.apache.cassandra.index.sai.disk.v5.V5VectorPostingsWriter; import org.apache.cassandra.index.sai.disk.vector.BruteForceRowIdIterator; import org.apache.cassandra.index.sai.disk.vector.CassandraDiskAnn; @@ -64,8 +64,8 @@ import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.RowIdWithMeta; import org.apache.cassandra.index.sai.utils.RowIdWithScore; -import org.apache.cassandra.index.sai.utils.SegmentOrdering; import org.apache.cassandra.io.sstable.format.SSTableReader; import org.apache.cassandra.metrics.LinearFit; import org.apache.cassandra.metrics.PairedSlidingWindowReservoir; @@ -81,7 +81,7 @@ /** * Executes ann search against the graph for an individual index segment. 
*/ -public class V2VectorIndexSearcher extends IndexSearcher implements SegmentOrdering +public class V2VectorIndexSearcher extends IndexSearcher { private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); @@ -133,13 +133,13 @@ public ProductQuantization getPQ() } @Override - public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer, int limit) throws IOException + public KeyRangeIterator search(Expression exp, AbstractBounds keyRange, QueryContext context, boolean defer) throws IOException { - PostingList results = searchPosting(context, exp, keyRange, limit); + PostingList results = searchPosting(context, exp, keyRange); return toPrimaryKeyIterator(results, context); } - private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange, int limit) throws IOException + private PostingList searchPosting(QueryContext context, Expression exp, AbstractBounds keyRange) throws IOException { if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), exp); @@ -151,7 +151,7 @@ private PostingList searchPosting(QueryContext context, Expression exp, Abstract // this is a thresholded query, so pass graph.size() as top k to get all results satisfying the threshold var result = searchInternal(keyRange, context, queryVector, graph.size(), graph.size(), exp.getEuclideanSearchThreshold()); - return new VectorPostingList(result); + return new ReorderingPostingList(result, RowIdWithMeta::getSegmentRowId); } @Override @@ -160,11 +160,11 @@ public CloseableIterator orderBy(Orderer orderer, Express if (logger.isTraceEnabled()) logger.trace(indexContext.logMessage("Searching on expression '{}'..."), orderer); - if (orderer.vector == null) + if (!orderer.isANN()) throw new 
IllegalArgumentException(indexContext.logMessage("Unsupported expression during ANN index query: " + orderer)); int rerankK = indexContext.getIndexWriterConfig().getSourceModel().rerankKFor(limit, graph.getCompression()); - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var result = searchInternal(keyRange, context, queryVector, limit, rerankK, 0); return toMetaSortedIterator(result, context); @@ -485,14 +485,14 @@ public CloseableIterator orderResultsBy(SSTableReader rea if (cost.shouldUseBruteForce()) { // brute force using the in-memory compressed vectors to cut down the number of results returned - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); return toMetaSortedIterator(this.orderByBruteForce(queryVector, segmentOrdinalPairs, limit, rerankK), context); } // Create bits from the mapping var bits = bitSetForSearch(); segmentOrdinalPairs.forEachRightInt(bits::set); // else ask the index to perform a search limited to the bits we created - var queryVector = vts.createFloatVector(orderer.vector); + var queryVector = vts.createFloatVector(orderer.getVectorTerm()); var results = graph.search(queryVector, limit, rerankK, 0, bits, context, cost::updateStatistics); return toMetaSortedIterator(results, context); } diff --git a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java index 25231748149f..cd2f48cda384 100644 --- a/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java @@ -222,7 +222,7 @@ public List> orderBy(QueryContext conte assert slice == null : "ANN does not support index slicing"; assert orderer.isANN() : "Only ANN is supported for vector search, received " + orderer.operator; - var qv = 
vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); return List.of(searchInternal(context, qv, keyRange, limit, 0)); } @@ -310,7 +310,7 @@ public CloseableIterator orderResultsBy(QueryContext cont relevantOrdinals.size(), keys.size(), maxBruteForceRows, graph.size(), limit); // convert the expression value to query vector - var qv = vts.createFloatVector(orderer.vector); + var qv = vts.createFloatVector(orderer.getVectorTerm()); // brute force path if (keysInGraph.size() <= maxBruteForceRows) { diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java index 57517d9192d0..3c7f945a4194 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemoryIndex.java @@ -662,6 +662,13 @@ public void close() throws IOException } } + /** + * Iterator that provides ordered access to all indexed terms and their associated primary keys + * in the TrieMemoryIndex. For each term in the index, yields PrimaryKeyWithSortKey objects that + * combine a primary key with its associated term. + *

+ * A more verbose name could be KeysMatchingTermsByTermIterator. + */ private class AllTermsIterator extends AbstractIterator { private final Iterator> iterator; diff --git a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java index 5a42468e2c2b..fefffa8d5d83 100644 --- a/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java +++ b/src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -32,22 +33,30 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.Runnables; +import org.apache.cassandra.cql3.Operator; import org.apache.cassandra.db.Clustering; +import org.apache.cassandra.db.DataRange; import org.apache.cassandra.db.DecoratedKey; import org.apache.cassandra.db.PartitionPosition; +import org.apache.cassandra.db.RegularAndStaticColumns; +import org.apache.cassandra.db.filter.ColumnFilter; import org.apache.cassandra.db.marshal.AbstractType; import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.memtable.ShardBoundaries; import org.apache.cassandra.db.memtable.TrieMemtable; +import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.AbstractBounds; +import org.apache.cassandra.dht.Bounds; import org.apache.cassandra.index.sai.IndexContext; import org.apache.cassandra.index.sai.QueryContext; import org.apache.cassandra.index.sai.disk.format.Version; -import org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeConcatIterator; +import org.apache.cassandra.index.sai.iterators.KeyRangeIntersectionIterator; import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; +import 
org.apache.cassandra.index.sai.iterators.KeyRangeLazyIterator; import org.apache.cassandra.index.sai.plan.Expression; import org.apache.cassandra.index.sai.plan.Orderer; +import org.apache.cassandra.index.sai.utils.BM25Utils; import org.apache.cassandra.index.sai.utils.PrimaryKey; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithByteComparable; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; @@ -236,15 +245,45 @@ public List> orderBy(QueryContext query int startShard = boundaries.getShardForToken(keyRange.left.getToken()); int endShard = keyRange.right.isMinimum() ? boundaries.shardCount() - 1 : boundaries.getShardForToken(keyRange.right.getToken()); - var iterators = new ArrayList>(endShard - startShard + 1); - - for (int shard = startShard; shard <= endShard; ++shard) + if (!orderer.isBM25()) { - assert rangeIndexes[shard] != null; - iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + var iterators = new ArrayList>(endShard - startShard + 1); + for (int shard = startShard; shard <= endShard; ++shard) + { + assert rangeIndexes[shard] != null; + iterators.add(rangeIndexes[shard].orderBy(orderer, slice)); + } + return iterators; } - return iterators; + // BM25 + var queryTerms = orderer.getQueryTerms(); + + // Intersect iterators to find documents containing all terms + var termIterators = keyIteratorsPerTerm(queryContext, keyRange, queryTerms); + var intersectedIterator = KeyRangeIntersectionIterator.builder(termIterators).build(); + + // Compute BM25 scores + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + return List.of(BM25Utils.computeScores(intersectedIterator, + queryTerms, + docStats, + indexContext, + memtable, + this::getCellForKey)); + } + + private List keyIteratorsPerTerm(QueryContext queryContext, AbstractBounds keyRange, List queryTerms) + { + List termIterators = new ArrayList<>(queryTerms.size()); + for (ByteBuffer term : queryTerms) + { + Expression expr = new Expression(indexContext); + 
expr.add(Operator.ANALYZER_MATCHES, term); + KeyRangeIterator iterator = search(queryContext, expr, keyRange, Integer.MAX_VALUE); + termIterators.add(iterator); + } + return termIterators; } @Override @@ -256,32 +295,99 @@ public long estimateMatchingRowsCount(Expression expression, AbstractBounds orderResultsBy(QueryContext context, List keys, Orderer orderer, int limit) + public CloseableIterator orderResultsBy(QueryContext queryContext, List keys, Orderer orderer, int limit) { if (keys.isEmpty()) return CloseableIterator.emptyIterator(); - return SortingIterator.createCloseable( - orderer.getComparator(), - keys, - key -> + + if (!orderer.isBM25()) + { + return SortingIterator.createCloseable( + orderer.getComparator(), + keys, + key -> + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + return null; + + // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what + // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. 
+ var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); + return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); + }, + Runnables.doNothing() + ); + } + + // BM25 + var queryTerms = orderer.getQueryTerms(); + var docStats = computeDocumentFrequencies(queryContext, queryTerms); + return BM25Utils.computeScores(keys.iterator(), + queryTerms, + docStats, + indexContext, + memtable, + this::getCellForKey); + } + + /** + * Count document frequencies for each term using brute force + */ + private BM25Utils.DocStats computeDocumentFrequencies(QueryContext queryContext, List queryTerms) + { + var termIterators = keyIteratorsPerTerm(queryContext, Bounds.unbounded(indexContext.getPartitioner()), queryTerms); + var documentFrequencies = new HashMap(); + for (int i = 0; i < queryTerms.size(); i++) + { + // KeyRangeIterator.getMaxKeys is not accurate enough, we have to count them + long keys = 0; + for (var it = termIterators.get(i); it.hasNext(); it.next()) + keys++; + documentFrequencies.put(queryTerms.get(i), keys); + } + long docCount = 0; + + // count all documents in the queried column + try (var it = memtable.makePartitionIterator(ColumnFilter.selection(RegularAndStaticColumns.of(indexContext.getDefinition())), + DataRange.allData(memtable.metadata().partitioner))) + { + while (it.hasNext()) { - var partition = memtable.getPartition(key.partitionKey()); - if (partition == null) - return null; - var row = partition.getRow(key.clustering()); - if (row == null) - return null; - var cell = row.getCell(indexContext.getDefinition()); - if (cell == null) - return null; - - // We do two kinds of encoding... it'd be great to make this more straight forward, but this is what - // we have for now. I leave it to the reader to inspect the two methods to see the nuanced differences. 
- var encoding = encode(TypeUtil.encode(cell.buffer(), validator)); - return new PrimaryKeyWithByteComparable(indexContext, memtable, key, encoding); - }, - Runnables.doNothing() - ); + var partitions = it.next(); + while (partitions.hasNext()) + { + var unfiltered = partitions.next(); + if (!unfiltered.isRow()) + continue; + var row = (Row) unfiltered; + var cell = row.getCell(indexContext.getDefinition()); + if (cell == null) + continue; + + docCount++; + } + } + } + return new BM25Utils.DocStats(documentFrequencies, docCount); + } + + @Nullable + private org.apache.cassandra.db.rows.Cell getCellForKey(PrimaryKey key) + { + var partition = memtable.getPartition(key.partitionKey()); + if (partition == null) + return null; + var row = partition.getRow(key.clustering()); + if (row == null) + return null; + return row.getCell(indexContext.getDefinition()); } private ByteComparable encode(ByteBuffer input) diff --git a/src/java/org/apache/cassandra/index/sai/plan/Expression.java b/src/java/org/apache/cassandra/index/sai/plan/Expression.java index a1bd9acddc4b..aac8e240829b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Expression.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Expression.java @@ -101,6 +101,7 @@ public static Op valueOf(Operator operator) return IN; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: return ORDER_BY; @@ -250,6 +251,7 @@ public Expression add(Operator op, ByteBuffer value) boundedAnnEuclideanDistanceThreshold = GeoUtil.amplifiedEuclideanSimilarityThreshold(lower.value.vector, searchRadiusMeters); break; case ANN: + case BM25: case ORDER_BY_ASC: case ORDER_BY_DESC: // If we alread have an operation on the column, we don't need to set the ORDER_BY op because diff --git a/src/java/org/apache/cassandra/index/sai/plan/Operation.java b/src/java/org/apache/cassandra/index/sai/plan/Operation.java index abc9735ce510..71cd781def91 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Operation.java +++ 
b/src/java/org/apache/cassandra/index/sai/plan/Operation.java @@ -93,7 +93,7 @@ protected static ListMultimap analyzeGroup(QueryCont analyzer.reset(e.getIndexValue()); // EQ/LIKE_*/NOT_EQ can have multiple expressions e.g. text = "Hello World", - // becomes text = "Hello" OR text = "World" because "space" is always interpreted as a split point (by analyzer), + // becomes text = "Hello" AND text = "World" because "space" is always interpreted as a split point (by analyzer), // CONTAINS/CONTAINS_KEY are always treated as multiple expressions since they currently only targetting // collections, NOT_EQ is made an independent expression only in case of pre-existing multiple EQ expressions, or // if there is no EQ operations and NOT_EQ is met or a single NOT_EQ expression present, @@ -102,6 +102,7 @@ protected static ListMultimap analyzeGroup(QueryCont boolean isMultiExpression = columnIsMultiExpression.getOrDefault(e.column(), Boolean.FALSE); switch (e.operator()) { + // case BM25: leave it at the default of `false` case EQ: // EQ operator will always be a multiple expression because it is being used by map entries isMultiExpression = indexContext.isNonFrozenCollection(); diff --git a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java index e23c83ed91e2..af202bae5f5b 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Orderer.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Orderer.java @@ -19,9 +19,12 @@ package org.apache.cassandra.index.sai.plan; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.EnumSet; +import java.util.HashSet; +import java.util.List; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -41,12 +44,15 @@ public class Orderer { // The list of operators that are valid for order by clauses. 
static final EnumSet ORDER_BY_OPERATORS = EnumSet.of(Operator.ANN, + Operator.BM25, Operator.ORDER_BY_ASC, Operator.ORDER_BY_DESC); public final IndexContext context; public final Operator operator; - public final float[] vector; + public final ByteBuffer term; + private float[] vector; + private List queryTerms; /** * Create an orderer for the given index context, operator, and term. @@ -59,7 +65,7 @@ public Orderer(IndexContext context, Operator operator, ByteBuffer term) this.context = context; assert ORDER_BY_OPERATORS.contains(operator) : "Invalid operator for order by clause " + operator; this.operator = operator; - this.vector = context.getValidator().isVector() ? TypeUtil.decomposeVector(context.getValidator(), term) : null; + this.term = term; } public String getIndexName() @@ -75,8 +81,8 @@ public boolean isAscending() public Comparator getComparator() { - // ANN's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue - return isAscending() || isANN() ? Comparator.naturalOrder() : Comparator.reverseOrder(); + // ANN/BM25's PrimaryKeyWithSortKey is always descending, so we use the natural order for the priority queue + return (isAscending() || isANN() || isBM25()) ? Comparator.naturalOrder() : Comparator.reverseOrder(); } public boolean isLiteral() @@ -89,6 +95,11 @@ public boolean isANN() return operator == Operator.ANN; } + public boolean isBM25() + { + return operator == Operator.BM25; + } + @Nullable public static Orderer from(SecondaryIndexManager indexManager, RowFilter filter) { @@ -110,8 +121,38 @@ public static boolean isFilterExpressionOrderer(RowFilter.Expression expression) public String toString() { String direction = isAscending() ? "ASC" : "DESC"; - return isANN() - ? 
context.getColumnName() + " ANN OF " + Arrays.toString(vector) + ' ' + direction - : context.getColumnName() + ' ' + direction; + if (isANN()) + return context.getColumnName() + " ANN OF " + Arrays.toString(getVectorTerm()) + ' ' + direction; + if (isBM25()) + return context.getColumnName() + " BM25 OF " + TypeUtil.getString(term, context.getValidator()) + ' ' + direction; + return context.getColumnName() + ' ' + direction; + } + + public float[] getVectorTerm() + { + if (vector == null) + vector = TypeUtil.decomposeVector(context.getValidator(), term); + return vector; + } + + public List getQueryTerms() + { + if (queryTerms != null) + return queryTerms; + + var queryAnalyzer = context.getQueryAnalyzerFactory().create(); + // Split query into terms + var uniqueTerms = new HashSet(); + queryAnalyzer.reset(term); + try + { + queryAnalyzer.forEachRemaining(uniqueTerms::add); + } + finally + { + queryAnalyzer.end(); + } + queryTerms = new ArrayList<>(uniqueTerms); + return queryTerms; } } diff --git a/src/java/org/apache/cassandra/index/sai/plan/Plan.java b/src/java/org/apache/cassandra/index/sai/plan/Plan.java index d34795cd5c2a..877431006cd0 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/Plan.java +++ b/src/java/org/apache/cassandra/index/sai/plan/Plan.java @@ -1244,10 +1244,12 @@ protected double estimateSelectivity() @Override protected KeysIterationCost estimateCost() { - return ordering.isANN() - ? 
estimateAnnSortCost() - : estimateGlobalSortCost(); - + if (ordering.isANN()) + return estimateAnnSortCost(); + else if (ordering.isBM25()) + return estimateBm25SortCost(); + else + return estimateGlobalSortCost(); } private KeysIterationCost estimateAnnSortCost() @@ -1264,6 +1266,21 @@ private KeysIterationCost estimateAnnSortCost() return new KeysIterationCost(expectedKeys, initCost, searchCost); } + private KeysIterationCost estimateBm25SortCost() + { + double expectedKeys = access.expectedAccessCount(source.expectedKeys()); + + int termCount = ordering.getQueryTerms().size(); + // all of the cost for BM25 is up front since the index doesn't give us the information we need + // to return results in order, in isolation. The big cost is reading the indexed cells out of + // the sstables. + // VSTODO if we had stats on cell size _per column_ we could usefully include ROW_BYTE_COST + double initCost = source.fullCost() + + source.expectedKeys() * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + return new KeysIterationCost(expectedKeys, initCost, 0); + } + private KeysIterationCost estimateGlobalSortCost() { return new KeysIterationCost(source.expectedKeys(), @@ -1297,36 +1314,30 @@ protected KeysSort withAccess(Access access) } /** - * Returns all keys in ANN order. - * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. + * Base class for index scans that return results in a computed order (ANN, BM25) + * rather than the natural index order. 
*/ - final static class AnnIndexScan extends Leaf + abstract static class ScoredIndexScan extends Leaf { final Orderer ordering; - protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + protected ScoredIndexScan(Factory factory, int id, Access access, Orderer ordering) { super(factory, id, access); this.ordering = ordering; } + @Nullable @Override - protected KeysIterationCost estimateCost() + protected Orderer ordering() { - double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); - int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); - double searchCost = factory.costEstimator.estimateAnnSearchCost(ordering, - expectedKeysInt, - factory.tableMetrics.rows); - double initCost = 0; // negligible - return new KeysIterationCost(expectedKeys, initCost, searchCost); + return ordering; } - @Nullable @Override - protected Orderer ordering() + protected double estimateSelectivity() { - return ordering; + return 1.0; } @Override @@ -1335,6 +1346,30 @@ protected Iterator execute(Executor executor) int softLimit = max(1, round((float) access.expectedAccessCount(factory.tableMetrics.rows))); return executor.getTopKRows((Expression) null, softLimit); } + } + + /** + * Returns all keys in ANN order. + * Contrary to {@link KeysSort}, there is no input node here and the output is generated lazily. 
+ */ + final static class AnnIndexScan extends ScoredIndexScan + { + protected AnnIndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Override + protected KeysIterationCost estimateCost() + { + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + double searchCost = factory.costEstimator.estimateAnnSearchCost(ordering, + expectedKeysInt, + factory.tableMetrics.rows); + double initCost = 0; // negligible + return new KeysIterationCost(expectedKeys, initCost, searchCost); + } @Override protected KeysIteration withAccess(Access access) @@ -1343,11 +1378,39 @@ protected KeysIteration withAccess(Access access) ? this : new AnnIndexScan(factory, id, access, ordering); } + } + /** + * Returns all keys in BM25 order. + * Like AnnIndexScan, this generates results lazily without an input node. + */ + final static class Bm25IndexScan extends ScoredIndexScan + { + protected Bm25IndexScan(Factory factory, int id, Access access, Orderer ordering) + { + super(factory, id, access, ordering); + } + + @Nonnull @Override - protected double estimateSelectivity() + protected KeysIterationCost estimateCost() { - return 1.0; + double expectedKeys = access.expectedAccessCount(factory.tableMetrics.rows); + int expectedKeysInt = Math.max(1, (int) Math.ceil(expectedKeys)); + + int termCount = ordering.getQueryTerms().size(); + double initCost = expectedKeysInt * (hrs(ROW_CELL_COST) + ROW_CELL_COST) + + termCount * BM25_SCORE_COST; + + return new KeysIterationCost(expectedKeys, initCost, 0); + } + + @Override + protected KeysIteration withAccess(Access access) + { + return Objects.equals(access, this.access) + ? 
this + : new Bm25IndexScan(factory, id, access, ordering); } } @@ -1665,6 +1728,8 @@ private KeysIteration indexScan(Expression predicate, long matchingKeysCount, Or if (ordering != null) if (ordering.isANN()) return new AnnIndexScan(this, id, defaultAccess, ordering); + else if (ordering.isBM25()) + return new Bm25IndexScan(this, id, defaultAccess, ordering); else if (ordering.isLiteral()) return new LiteralIndexScan(this, id, predicate, matchingKeysCount, defaultAccess, ordering); else @@ -1919,6 +1984,9 @@ public static class CostCoefficients /** Additional cost added to row fetch cost per each serialized byte of the row */ public final static double ROW_BYTE_COST = 0.005; + + /** Cost to perform BM25 scoring, per query term */ + public final static double BM25_SCORE_COST = 0.5; } /** Convenience builder for building intersection and union nodes */ diff --git a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java index 20191fb8d830..ae8147cba020 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/QueryController.java +++ b/src/java/org/apache/cassandra/index/sai/plan/QueryController.java @@ -201,6 +201,11 @@ public TableMetadata metadata() return command.metadata(); } + public ReadCommand command() + { + return command; + } + RowFilter.FilterElement filterOperation() { // NOTE: we cannot remove the order by filter expression here yet because it is used in the FilterTree class @@ -859,6 +864,7 @@ private long estimateMatchingRowCount(Expression predicate) switch (predicate.getOp()) { case EQ: + case MATCH: case CONTAINS_KEY: case CONTAINS_VALUE: case NOT_EQ: diff --git a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java index cee2d69978fd..cb54764ec7a5 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java +++ 
b/src/java/org/apache/cassandra/index/sai/plan/StorageAttachedIndexSearcher.java @@ -29,12 +29,10 @@ import java.util.Queue; import java.util.function.Supplier; import java.util.stream.Collectors; - import javax.annotation.Nonnull; import javax.annotation.Nullable; import com.google.common.base.Preconditions; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,8 +42,13 @@ import org.apache.cassandra.db.PartitionPosition; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.ReadExecutionController; +import org.apache.cassandra.db.filter.ColumnFilter; +import org.apache.cassandra.db.marshal.FloatType; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.AbstractUnfilteredRowIterator; +import org.apache.cassandra.db.rows.BTreeRow; +import org.apache.cassandra.db.rows.BufferCell; +import org.apache.cassandra.db.rows.ColumnData; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.db.rows.UnfilteredRowIterator; @@ -59,13 +62,16 @@ import org.apache.cassandra.index.sai.iterators.KeyRangeIterator; import org.apache.cassandra.index.sai.metrics.TableQueryMetrics; import org.apache.cassandra.index.sai.utils.PrimaryKey; -import org.apache.cassandra.index.sai.utils.RangeUtil; +import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore; import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey; +import org.apache.cassandra.index.sai.utils.RangeUtil; import org.apache.cassandra.io.util.FileUtils; +import org.apache.cassandra.schema.ColumnMetadata; import org.apache.cassandra.schema.TableMetadata; import org.apache.cassandra.utils.AbstractIterator; import org.apache.cassandra.utils.CloseableIterator; import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.btree.BTree; public class StorageAttachedIndexSearcher implements Index.Searcher { @@ -105,19 +111,17 @@ public 
UnfilteredPartitionIterator search(ReadExecutionController executionContr // Can't check for `command.isTopK()` because the planner could optimize sorting out Orderer ordering = plan.ordering(); - if (ordering != null) - { - assert !(keysIterator instanceof KeyRangeIterator); - var scoredKeysIterator = (CloseableIterator) keysIterator; - var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, - executionController, queryContext, command.limits().count()); - return (UnfilteredPartitionIterator) new TopKProcessor(command).filter(result); - } - else + if (ordering == null) { assert keysIterator instanceof KeyRangeIterator; return new ResultRetriever((KeyRangeIterator) keysIterator, filterTree, controller, executionController, queryContext); } + + assert !(keysIterator instanceof KeyRangeIterator); + var scoredKeysIterator = (CloseableIterator) keysIterator; + var result = new ScoreOrderedResultRetriever(scoredKeysIterator, filterTree, controller, + executionController, queryContext, command.limits().count()); + return (UnfilteredPartitionIterator) new TopKProcessor(command).filter(result); } catch (Throwable t) { @@ -496,26 +500,27 @@ public UnfilteredRowIterator computeNext() */ private void fillPendingRows() { + // Group PKs by source sstable/memtable + var groupedKeys = new HashMap>(); // We always want to get at least 1. 
int rowsToRetrieve = Math.max(1, softLimit - returnedRowCount); - var keys = new HashMap>(); // We want to get the first unique `rowsToRetrieve` keys to materialize // Don't pass the priority queue here because it is more efficient to add keys in bulk - fillKeys(keys, rowsToRetrieve, null); + fillKeys(groupedKeys, rowsToRetrieve, null); // Sort the primary keys by PrK order, just in case that helps with cache and disk efficiency - var primaryKeyPriorityQueue = new PriorityQueue<>(keys.keySet()); + var primaryKeyPriorityQueue = new PriorityQueue<>(groupedKeys.keySet()); - while (!keys.isEmpty()) + while (!groupedKeys.isEmpty()) { - var primaryKey = primaryKeyPriorityQueue.poll(); - var primaryKeyWithSortKeys = keys.remove(primaryKey); - var partitionIterator = readAndValidatePartition(primaryKey, primaryKeyWithSortKeys); + var pk = primaryKeyPriorityQueue.poll(); + var sourceKeys = groupedKeys.remove(pk); + var partitionIterator = readAndValidatePartition(pk, sourceKeys); if (partitionIterator != null) pendingRows.add(partitionIterator); else // The current primaryKey did not produce a partition iterator. We know the caller will need // `rowsToRetrieve` rows, so we get the next unique key and add it to the queue. - fillKeys(keys, 1, primaryKeyPriorityQueue); + fillKeys(groupedKeys, 1, primaryKeyPriorityQueue); } } @@ -523,21 +528,21 @@ private void fillPendingRows() * Fills the keys map with the next `count` unique primary keys that are in the keys produced by calling * {@link #nextSelectedKeyInRange()}. We map PrimaryKey to List because the same * primary key can be in the result set multiple times, but with different source tables. - * @param keys the map to fill + * @param groupedKeys the map to fill * @param count the number of unique PrimaryKeys to consume from the iterator * @param primaryKeyPriorityQueue the priority queue to add new keys to. If the queue is null, we do not add * keys to the queue. 
*/ - private void fillKeys(Map> keys, int count, PriorityQueue primaryKeyPriorityQueue) + private void fillKeys(Map> groupedKeys, int count, PriorityQueue primaryKeyPriorityQueue) { - int initialSize = keys.size(); - while (keys.size() - initialSize < count) + int initialSize = groupedKeys.size(); + while (groupedKeys.size() - initialSize < count) { var primaryKeyWithSortKey = nextSelectedKeyInRange(); if (primaryKeyWithSortKey == null) return; var nextPrimaryKey = primaryKeyWithSortKey.primaryKey(); - var accumulator = keys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); + var accumulator = groupedKeys.computeIfAbsent(nextPrimaryKey, k -> new ArrayList<>()); if (primaryKeyPriorityQueue != null && accumulator.isEmpty()) primaryKeyPriorityQueue.add(nextPrimaryKey); accumulator.add(primaryKeyWithSortKey); @@ -577,15 +582,29 @@ private boolean isInRange(DecoratedKey key) return null; } - public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeys) + /** + * Reads and validates a partition for a given primary key against its sources. + *

+ * @param pk The primary key of the partition to read and validate + * @param sourceKeys A list of PrimaryKeyWithSortKey objects associated with the primary key. + * Multiple sort keys can exist for the same primary key when data comes from different + * sstables or memtables. + * + * @return An UnfilteredRowIterator containing the validated partition data, or null if: + * - The key has already been processed + * - The partition does not pass index filters + * - The partition contains no valid rows + * - The row data does not match the index metadata for any of the provided primary keys + */ + public UnfilteredRowIterator readAndValidatePartition(PrimaryKey pk, List sourceKeys) { // If we've already processed the key, we can skip it. Because the score ordered iterator does not // deduplicate rows, we could see dupes if a row is in the ordering index multiple times. This happens // in the case of dupes and of overwrites. - if (processedKeys.contains(key)) + if (processedKeys.contains(pk)) return null; - try (UnfilteredRowIterator partition = controller.getPartition(key, view, executionController)) + try (UnfilteredRowIterator partition = controller.getPartition(pk, view, executionController)) { queryContext.addPartitionsRead(1); queryContext.checkpoint(); @@ -594,7 +613,7 @@ public UnfilteredRowIterator readAndValidatePartition(PrimaryKey key, List primaryKeysWithScore, ReadCommand command) { super(partition.metadata(), partition.partitionKey(), @@ -637,7 +668,47 @@ public PrimaryKeyIterator(UnfilteredRowIterator partition, Row staticRow, Unfilt partition.isReverseOrder(), partition.stats()); - row = content; + assert !primaryKeysWithScore.isEmpty(); + var isScoredRow = primaryKeysWithScore.get(0) instanceof PrimaryKeyWithScore; + if (!content.isRow() || !isScoredRow) + { + this.row = content; + return; + } + + // When +score is added on the coordinator side, it's represented as a PrecomputedColumnFilter + // even in a 'SELECT *' because WCF is not capable of 
representing synthetic columns. + // This can be simplified when we remove ANN_USE_SYNTHETIC_SCORE + var tm = metadata(); + var scoreColumn = ColumnMetadata.syntheticColumn(tm.keyspace, + tm.name, + ColumnMetadata.SYNTHETIC_SCORE_ID, + FloatType.instance); + var isScoreFetched = command.columnFilter().fetchesExplicitly(scoreColumn); + if (!isScoreFetched) + { + this.row = content; + return; + } + + // Clone the original Row + Row originalRow = (Row) content; + ArrayList columnData = new ArrayList<>(originalRow.columnCount() + 1); + columnData.addAll(originalRow.columnData()); + + // inject +score as a new column + var pkWithScore = (PrimaryKeyWithScore) primaryKeysWithScore.get(0); + columnData.add(BufferCell.live(scoreColumn, + FBUtilities.nowInSeconds(), + FloatType.instance.decompose(pkWithScore.indexScore))); + + this.row = BTreeRow.create(originalRow.clustering(), + originalRow.primaryKeyLivenessInfo(), + originalRow.deletion(), + BTree.builder(ColumnData.comparator) + .auto(true) + .addAll(columnData) + .build()); } @Override @@ -649,18 +720,6 @@ protected Unfiltered computeNext() return row; } } - - @Override - public TableMetadata metadata() - { - return controller.metadata(); - } - - public void close() - { - FileUtils.closeQuietly(scoredPrimaryKeyIterator); - controller.finish(); - } } private static UnfilteredRowIterator applyIndexFilter(UnfilteredRowIterator partition, FilterTree tree, QueryContext queryContext) diff --git a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java index 147acf41eaef..3d5bce3d607a 100644 --- a/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java +++ b/src/java/org/apache/cassandra/index/sai/plan/TopKProcessor.java @@ -48,11 +48,13 @@ import org.apache.cassandra.db.Keyspace; import org.apache.cassandra.db.ReadCommand; import org.apache.cassandra.db.filter.RowFilter; +import org.apache.cassandra.db.marshal.FloatType; import 
org.apache.cassandra.db.partitions.BasePartitionIterator; import org.apache.cassandra.db.partitions.ParallelCommandProcessor; import org.apache.cassandra.db.partitions.PartitionIterator; import org.apache.cassandra.db.partitions.UnfilteredPartitionIterator; import org.apache.cassandra.db.rows.BaseRowIterator; +import org.apache.cassandra.db.rows.Cell; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.db.rows.Unfiltered; import org.apache.cassandra.index.Index; @@ -99,6 +101,7 @@ public class TopKProcessor private final IndexContext indexContext; private final RowFilter.Expression expression; private final VectorFloat queryVector; + private final ColumnMetadata scoreColumn; private final int limit; @@ -106,18 +109,19 @@ public TopKProcessor(ReadCommand command) { this.command = command; - Pair annIndexAndExpression = findTopKIndexContext(); + Pair indexAndExpression = findTopKIndexContext(); // this can happen in case an index was dropped after the query was initiated - if (annIndexAndExpression == null) + if (indexAndExpression == null) throw invalidRequest(INDEX_MAY_HAVE_BEEN_DROPPED); - this.indexContext = annIndexAndExpression.left; - this.expression = annIndexAndExpression.right; + this.indexContext = indexAndExpression.left; + this.expression = indexAndExpression.right; if (expression.operator() == Operator.ANN) this.queryVector = vts.createFloatVector(TypeUtil.decomposeVector(indexContext, expression.getIndexValue().duplicate())); else this.queryVector = null; this.limit = command.limits().count(); + this.scoreColumn = ColumnMetadata.syntheticColumn(indexContext.getKeyspace(), indexContext.getTable(), ColumnMetadata.SYNTHETIC_SCORE_ID, FloatType.instance); } /** @@ -163,8 +167,10 @@ private , P extends BaseParti { // priority queue ordered by score in descending order Comparator> comparator; - if (queryVector != null) + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) + { comparator = 
Comparator.comparing((Triple t) -> (Float) t.getRight()).reversed(); + } else { comparator = Comparator.comparing(t -> (ByteBuffer) t.getRight(), indexContext.getValidator()); @@ -192,7 +198,7 @@ private , P extends BaseParti executor.maybeExecuteImmediately(() -> { try (var partitionRowIterator = pIter.commandToIterator(command.left(), command.right())) { - future.complete(partitionRowIterator == null ? null : processPartition(partitionRowIterator)); + future.complete(partitionRowIterator == null ? null : processScoredPartition(partitionRowIterator)); } catch (Throwable t) { @@ -240,9 +246,9 @@ private , P extends BaseParti // have to close to move to the next partition, otherwise hasNext() fails try (var partitionRowIterator = partitions.next()) { - if (queryVector != null) + if (expression.operator() == Operator.ANN || expression.operator() == Operator.BM25) { - PartitionResults pr = processPartition(partitionRowIterator); + PartitionResults pr = processScoredPartition(partitionRowIterator); topK.addAll(pr.rows); for (var uf: pr.tombstones) addUnfiltered(unfilteredByPartition, pr.partitionInfo, uf); @@ -255,7 +261,6 @@ private , P extends BaseParti topK.add(Triple.of(PartitionInfo.create(partitionRowIterator), row, row.getCell(expression.column()).buffer())); } } - } } } @@ -291,7 +296,7 @@ void addRow(Triple triple) { /** * Processes a single partition, calculating scores for rows and extracting tombstones. 
*/ - private PartitionResults processPartition(BaseRowIterator partitionRowIterator) { + private PartitionResults processScoredPartition(BaseRowIterator partitionRowIterator) { // Compute key and static row score once per partition DecoratedKey key = partitionRowIterator.partitionKey(); Row staticRow = partitionRowIterator.staticRow(); @@ -357,6 +362,14 @@ private float getScoreForRow(DecoratedKey key, Row row) if ((column.isClusteringColumn() || column.isRegular()) && row.isStatic()) return 0; + var scoreData = row.getColumnData(scoreColumn); + if (scoreData != null) + { + var cell = (Cell) scoreData; + return FloatType.instance.compose(cell.buffer()); + } + + // TODO remove this once we enable ANN_USE_SYNTHETIC_SCORE ByteBuffer value = indexContext.getValueOf(key, row, FBUtilities.nowInSeconds()); if (value != null) { @@ -373,21 +386,24 @@ private Pair findTopKIndexContext() for (RowFilter.Expression expression : command.rowFilter().expressions()) { - StorageAttachedIndex sai = findVectorIndexFor(cfs.indexManager, expression); + StorageAttachedIndex sai = findOrderingIndexFor(cfs.indexManager, expression); if (sai != null) - { return Pair.create(sai.getIndexContext(), expression); - } } return null; } @Nullable - private StorageAttachedIndex findVectorIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) + private StorageAttachedIndex findOrderingIndexFor(SecondaryIndexManager sim, RowFilter.Expression e) { - if (e.operator() != Operator.ANN && e.operator() != Operator.ORDER_BY_ASC && e.operator() != Operator.ORDER_BY_DESC) + if (e.operator() != Operator.ANN + && e.operator() != Operator.BM25 + && e.operator() != Operator.ORDER_BY_ASC + && e.operator() != Operator.ORDER_BY_DESC) + { return null; + } Optional index = sim.getBestIndexFor(e); return (StorageAttachedIndex) index.filter(i -> i instanceof StorageAttachedIndex).orElse(null); diff --git a/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java 
b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java new file mode 100644 index 000000000000..cc4fa1f66b04 --- /dev/null +++ b/src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.cassandra.db.memtable.Memtable; +import org.apache.cassandra.db.rows.Cell; +import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer; +import org.apache.cassandra.io.sstable.SSTableId; +import org.apache.cassandra.utils.CloseableIterator; + +public class BM25Utils +{ + private static final float K1 = 1.2f; // BM25 term frequency saturation parameter + private static final float B = 0.75f; // BM25 length normalization parameter + + /** + * Term frequencies across all documents. Each document is only counted once. 
+ */ + public static class DocStats + { + // Map of term -> count of docs containing that term + private final Map frequencies; + // total number of docs in the index + private final long docCount; + + public DocStats(Map frequencies, long docCount) + { + this.frequencies = frequencies; + this.docCount = docCount; + } + } + + /** + * Term frequencies within a single document. All instances of a term are counted. + */ + public static class DocTF + { + private final PrimaryKey pk; + private final Map frequencies; + private final int termCount; + + private DocTF(PrimaryKey pk, int termCount, Map frequencies) + { + this.pk = pk; + this.frequencies = frequencies; + this.termCount = termCount; + } + + public int getTermFrequency(ByteBuffer term) + { + return frequencies.getOrDefault(term, 0); + } + + public static DocTF createFromDocument(PrimaryKey pk, + Cell cell, + AbstractAnalyzer docAnalyzer, + Collection queryTerms) + { + int count = 0; + Map frequencies = new HashMap<>(); + + docAnalyzer.reset(cell.buffer()); + try + { + while (docAnalyzer.hasNext()) + { + ByteBuffer term = docAnalyzer.next(); + count++; + if (queryTerms.contains(term)) + frequencies.merge(term, 1, Integer::sum); + } + } + finally + { + docAnalyzer.end(); + } + + return new DocTF(pk, count, frequencies); + } + } + + @FunctionalInterface + public interface CellReader + { + Cell readCell(PrimaryKey pk); + } + + public static CloseableIterator computeScores(Iterator keyIterator, + List queryTerms, + DocStats docStats, + IndexContext indexContext, + Object source, + CellReader cellReader) + { + var docAnalyzer = indexContext.getAnalyzerFactory().create(); + + // data structures for document stats and frequencies + ArrayList documents = new ArrayList<>(); + double totalTermCount = 0; + + // Compute TF within each document + while (keyIterator.hasNext()) + { + var pk = keyIterator.next(); + var cell = cellReader.readCell(pk); + if (cell == null) + continue; + var tf = DocTF.createFromDocument(pk, cell, 
docAnalyzer, queryTerms); + + // sstable index will only send documents that contain all query terms to this method, + // but memtable is not indexed and will send all documents, so we have to skip documents + // that don't contain all query terms here to preserve consistency with sstable behavior + if (tf.frequencies.size() != queryTerms.size()) + continue; + + documents.add(tf); + + totalTermCount += tf.termCount; + } + + // Calculate average document length + double avgDocLength = !documents.isEmpty() ? totalTermCount / documents.size() : 0.0; + + // Calculate BM25 scores + var scoredDocs = new ArrayList(documents.size()); + for (var doc : documents) + { + double score = 0.0; + for (var queryTerm : queryTerms) + { + int tf = doc.getTermFrequency(queryTerm); + Long df = docStats.frequencies.get(queryTerm); + // we shouldn't have more hits for a term than we counted total documents + assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount); + + double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength)); + double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5)); + double deltaScore = normalizedTf * idf; + assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f", + tf, df, doc.termCount, docStats.docCount, deltaScore); + score += deltaScore; + } + if (source instanceof Memtable) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score)); + else if (source instanceof SSTableId) + scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score)); + else + throw new IllegalArgumentException("Invalid source " + source.getClass()); + } + + // sort by score (PKWS implements Comparator correctly for us) + Collections.sort(scoredDocs); + + return (CloseableIterator) (CloseableIterator) CloseableIterator.wrap(scoredDocs.iterator()); + } +} diff --git 
a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java index 837c032b7952..949eb21d282c 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithByteComparable.java @@ -21,7 +21,9 @@ import java.nio.ByteBuffer; import java.util.Arrays; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; import org.apache.cassandra.utils.bytecomparable.ByteSourceInverse; @@ -34,7 +36,13 @@ public class PrimaryKeyWithByteComparable extends PrimaryKeyWithSortKey { private final ByteComparable byteComparable; - public PrimaryKeyWithByteComparable(IndexContext context, Object sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + public PrimaryKeyWithByteComparable(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) + { + super(context, sourceTable, primaryKey); + this.byteComparable = byteComparable; + } + + public PrimaryKeyWithByteComparable(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey, ByteComparable byteComparable) { super(context, sourceTable, primaryKey); this.byteComparable = byteComparable; diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java index 8b7b4acba9c6..a10d6e82549a 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithScore.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; +import org.apache.cassandra.db.memtable.Memtable; import 
org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; /** * A {@link PrimaryKey} that includes a score from a source index. @@ -28,9 +30,15 @@ */ public class PrimaryKeyWithScore extends PrimaryKeyWithSortKey { - private final float indexScore; + public final float indexScore; - public PrimaryKeyWithScore(IndexContext context, Object source, PrimaryKey primaryKey, float indexScore) + public PrimaryKeyWithScore(IndexContext context, Memtable source, PrimaryKey primaryKey, float indexScore) + { + super(context, source, primaryKey); + this.indexScore = indexScore; + } + + public PrimaryKeyWithScore(IndexContext context, SSTableId source, PrimaryKey primaryKey, float indexScore) { super(context, source, primaryKey); this.indexScore = indexScore; diff --git a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java index 2e79b0402124..8a171fca4dbe 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java +++ b/src/java/org/apache/cassandra/index/sai/utils/PrimaryKeyWithSortKey.java @@ -22,9 +22,11 @@ import org.apache.cassandra.db.Clustering; import org.apache.cassandra.db.DecoratedKey; +import org.apache.cassandra.db.memtable.Memtable; import org.apache.cassandra.db.rows.Row; import org.apache.cassandra.dht.Token; import org.apache.cassandra.index.sai.IndexContext; +import org.apache.cassandra.io.sstable.SSTableId; import org.apache.cassandra.utils.bytecomparable.ByteComparable; import org.apache.cassandra.utils.bytecomparable.ByteSource; @@ -41,7 +43,14 @@ public abstract class PrimaryKeyWithSortKey implements PrimaryKey // Either a Memtable reference or an SSTableId reference private final Object sourceTable; - protected PrimaryKeyWithSortKey(IndexContext context, Object sourceTable, PrimaryKey primaryKey) + protected PrimaryKeyWithSortKey(IndexContext context, Memtable sourceTable, PrimaryKey primaryKey) + { 
+ this.context = context; + this.sourceTable = sourceTable; + this.primaryKey = primaryKey; + } + + protected PrimaryKeyWithSortKey(IndexContext context, SSTableId sourceTable, PrimaryKey primaryKey) { this.context = context; this.sourceTable = sourceTable; diff --git a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java index b4b1f129567f..c6ed4708e0c8 100644 --- a/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java +++ b/src/java/org/apache/cassandra/index/sai/utils/RowIdWithScore.java @@ -26,7 +26,7 @@ */ public class RowIdWithScore extends RowIdWithMeta { - private final float score; + public final float score; public RowIdWithScore(int segmentRowId, float score) { diff --git a/src/java/org/apache/cassandra/schema/ColumnMetadata.java b/src/java/org/apache/cassandra/schema/ColumnMetadata.java index 38784e72acb5..008813564e57 100644 --- a/src/java/org/apache/cassandra/schema/ColumnMetadata.java +++ b/src/java/org/apache/cassandra/schema/ColumnMetadata.java @@ -71,9 +71,9 @@ public enum ClusteringOrder /** * The type of CQL3 column this definition represents. - * There is 4 main type of CQL3 columns: those parts of the partition key, - * those parts of the clustering columns and amongst the others, regular and - * static ones. + * There are 5 types of columns: those parts of the partition key, + * those parts of the clustering columns and amongst the others, regular, + * static, and synthetic ones. * * IMPORTANT: this enum is serialized as toString() and deserialized by calling * Kind.valueOf(), so do not override toString() or rename existing values. 
@@ -81,18 +81,22 @@ public enum ClusteringOrder public enum Kind { // NOTE: if adding a new type, must modify comparisonOrder + SYNTHETIC, PARTITION_KEY, CLUSTERING, REGULAR, STATIC; + // it is not possible to add new Kinds after Synthetic without invasive changes to BTreeRow, which + // assumes that complex regular/static columns are the last ones public boolean isPrimaryKeyKind() { return this == PARTITION_KEY || this == CLUSTERING; } - } + public static final ColumnIdentifier SYNTHETIC_SCORE_ID = ColumnIdentifier.getInterned("+:!score", true); + /** * Whether this is a dropped column. */ @@ -121,10 +125,17 @@ public boolean isPrimaryKeyKind() */ private final long comparisonOrder; + /** + * Bit layout (from most to least significant): + * - Bits 61-63: Kind ordinal (3 bits, supporting up to 8 Kind values) + * - Bit 60: isComplex flag + * - Bits 48-59: position (12 bits, see assert) + * - Bits 0-47: name.prefixComparison (shifted right by 16) + */ private static long comparisonOrder(Kind kind, boolean isComplex, long position, ColumnIdentifier name) { assert position >= 0 && position < 1 << 12; - return (((long) kind.ordinal()) << 61) + return (((long) kind.ordinal()) << 61) | (isComplex ? 1L << 60 : 0) | (position << 48) | (name.prefixComparison >>> 16); @@ -170,6 +181,14 @@ public static ColumnMetadata staticColumn(String keyspace, String table, String return new ColumnMetadata(keyspace, table, ColumnIdentifier.getInterned(name, true), type, NO_POSITION, Kind.STATIC); } + /** + * Creates a new synthetic column metadata instance. + */ + public static ColumnMetadata syntheticColumn(String keyspace, String table, ColumnIdentifier id, AbstractType type) + { + return new ColumnMetadata(keyspace, table, id, type, NO_POSITION, Kind.SYNTHETIC); + } + /** * Rebuild the metadata for a dropped column from its recorded data.
* @@ -225,6 +244,7 @@ public ColumnMetadata(String ksName, this.kind = kind; this.position = position; this.cellPathComparator = makeCellPathComparator(kind, type); + assert kind != Kind.SYNTHETIC || cellPathComparator == null; this.cellComparator = cellPathComparator == null ? ColumnData.comparator : new Comparator>() { @Override @@ -461,7 +481,7 @@ public int compareTo(ColumnMetadata other) return 0; if (comparisonOrder != other.comparisonOrder) - return Long.compare(comparisonOrder, other.comparisonOrder); + return Long.compareUnsigned(comparisonOrder, other.comparisonOrder); return this.name.compareTo(other.name); } @@ -593,6 +613,11 @@ public boolean isCounterColumn() return type.isCounter(); } + public boolean isSynthetic() + { + return kind == Kind.SYNTHETIC; + } + public Selector.Factory newSelectorFactory(TableMetadata table, AbstractType expectedType, List defs, VariableSpecifications boundNames) throws InvalidRequestException { return SimpleSelector.newFactory(this, addAndGetIndex(this, defs)); diff --git a/src/java/org/apache/cassandra/schema/TableMetadata.java b/src/java/org/apache/cassandra/schema/TableMetadata.java index ba5e1db8d84d..8885b85fdd43 100644 --- a/src/java/org/apache/cassandra/schema/TableMetadata.java +++ b/src/java/org/apache/cassandra/schema/TableMetadata.java @@ -1122,8 +1122,7 @@ public Builder addStaticColumn(ColumnIdentifier name, AbstractType type) public Builder addColumn(ColumnMetadata column) { - if (columns.containsKey(column.name.bytes)) - throw new IllegalArgumentException(); + assert !columns.containsKey(column.name.bytes) : column.name + " is already present"; switch (column.kind) { diff --git a/src/java/org/apache/cassandra/service/ClientWarn.java b/src/java/org/apache/cassandra/service/ClientWarn.java index 5a6a878681e1..38570a06d2b8 100644 --- a/src/java/org/apache/cassandra/service/ClientWarn.java +++ b/src/java/org/apache/cassandra/service/ClientWarn.java @@ -18,7 +18,9 @@ package org.apache.cassandra.service; import 
java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import io.netty.util.concurrent.FastThreadLocal; import org.apache.cassandra.concurrent.ExecutorLocal; @@ -45,10 +47,18 @@ public void set(State value) } public void warn(String text) + { + warn(text, null); + } + + /** + * Issue the given warning if this is the first time `key` is seen. + */ + public void warn(String text, Object key) { State state = warnLocal.get(); if (state != null) - state.add(text); + state.add(text, key); } public void captureWarnings() @@ -72,11 +82,16 @@ public void resetWarnings() public static class State { private final List warnings = new ArrayList<>(); + private final Set keysAdded = new HashSet<>(); - private void add(String warning) + private void add(String warning, Object key) { if (warnings.size() < FBUtilities.MAX_UNSIGNED_SHORT) + { + if (key != null && !keysAdded.add(key)) + return; warnings.add(maybeTruncate(warning)); + } } private static String maybeTruncate(String warning) diff --git a/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java new file mode 100644 index 000000000000..49b5b5118540 --- /dev/null +++ b/test/burn/org/apache/cassandra/index/sai/LongBM25Test.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; + +import org.apache.cassandra.db.memtable.TrieMemtable; + +public class LongBM25Test extends SAITester +{ + private static final Logger logger = org.slf4j.LoggerFactory.getLogger(LongBM25Test.class); + + private static final List documentLines = new ArrayList<>(); + + static + { + try + { + var cl = LongBM25Test.class.getClassLoader(); + var resourceDir = cl.getResource("bm25"); + if (resourceDir == null) + throw new RuntimeException("Could not find resource directory test/resources/bm25/"); + + var dirPath = java.nio.file.Paths.get(resourceDir.toURI()); + try (var files = java.nio.file.Files.list(dirPath)) + { + files.forEach(file -> { + try (var lines = java.nio.file.Files.lines(file)) + { + lines.map(String::trim) + .filter(line -> !line.isEmpty()) + .forEach(documentLines::add); + } + catch (IOException e) + { + throw new RuntimeException("Failed to read file: " + file, e); + } + }); + } + if (documentLines.isEmpty()) + { + throw new 
RuntimeException("No document lines loaded from test/resources/bm25/"); + } + } + catch (IOException | URISyntaxException e) + { + throw new RuntimeException("Failed to load test documents", e); + } + } + + KeySet keysInserted = new KeySet(); + private final int threadCount = 12; + + @Before + public void setup() throws Throwable + { + // we don't get loaded until after TM, so we can't affect the very first memtable, + // but this will affect all subsequent ones + TrieMemtable.SHARD_COUNT = 4 * threadCount; + } + + @FunctionalInterface + private interface Op + { + void run(int i) throws Throwable; + } + + public void testConcurrentOps(Op op) throws ExecutionException, InterruptedException + { + createTable("CREATE TABLE %s (key int primary key, value text)"); + // Create analyzed index following BM25Test pattern + createIndex("CREATE CUSTOM INDEX ON %s(value) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + + AtomicInteger counter = new AtomicInteger(); + long start = System.currentTimeMillis(); + var fjp = new ForkJoinPool(threadCount); + var keys = IntStream.range(0, 10_000_000).boxed().collect(Collectors.toList()); + Collections.shuffle(keys); + var task = fjp.submit(() -> keys.stream().parallel().forEach(i -> + { + wrappedOp(op, i); + if (counter.incrementAndGet() % 10_000 == 0) + { + var elapsed = System.currentTimeMillis() - start; + logger.info("{} ops in {}ms = {} ops/s", counter.get(), elapsed, counter.get() * 1000.0 / elapsed); + } + if (ThreadLocalRandom.current().nextDouble() < 0.001) + flush(); + })); + fjp.shutdown(); + task.get(); // re-throw + } + + private static void wrappedOp(Op op, Integer i) + { + try + { + op.run(i); + } + catch (Throwable e) + { + throw new RuntimeException(e); + } + } + + private static String randomDocument() + { + var R = 
ThreadLocalRandom.current(); + int numLines = R.nextInt(5, 51); // 5 to 50 lines inclusive + var selectedLines = new ArrayList(); + + for (int i = 0; i < numLines; i++) + { + selectedLines.add(randomQuery(R)); + } + + return String.join("\n", selectedLines); + } + + private static String randomLine(ThreadLocalRandom R) + { + return documentLines.get(R.nextInt(documentLines.size())); + } + + @Test + public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else if (R.nextDouble() < 0.1) + { + var key = keysInserted.getRandom(); + execute("DELETE FROM %s WHERE key = ?", key); + } + else + { + var line = randomQuery(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + private static String randomQuery(ThreadLocalRandom R) + { + while (true) + { + var line = randomLine(R); + if (line.chars().anyMatch(Character::isAlphabetic)) + return line; + } + } + + @Test + public void testConcurrentReadsWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var R = ThreadLocalRandom.current(); + if (R.nextDouble() < 0.1 || keysInserted.isEmpty()) + { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + keysInserted.add(i); + } + else + { + var line = randomQuery(R); + execute("SELECT * FROM %s ORDER BY value BM25 OF ? 
LIMIT ?", line, R.nextInt(1, 100)); + } + }); + } + + @Test + public void testConcurrentWrites() throws ExecutionException, InterruptedException + { + testConcurrentOps(i -> { + var doc = randomDocument(); + execute("INSERT INTO %s (key, value) VALUES (?, ?)", i, doc); + }); + } + + private static class KeySet + { + private final Map keys = new ConcurrentHashMap<>(); + private final AtomicInteger ordinal = new AtomicInteger(); + + public void add(int key) + { + var i = ordinal.getAndIncrement(); + keys.put(i, key); + } + + public int getRandom() + { + if (isEmpty()) + throw new IllegalStateException(); + var i = ThreadLocalRandom.current().nextInt(ordinal.get()); + // in case there is race with add(key), retry another random + return keys.containsKey(i) ? keys.get(i) : getRandom(); + } + + public boolean isEmpty() + { + return keys.isEmpty(); + } + } +} diff --git a/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java new file mode 100644 index 000000000000..8feb61f36b83 --- /dev/null +++ b/test/distributed/org/apache/cassandra/distributed/test/sai/BM25DistributedTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.distributed.test.sai; + +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.distributed.Cluster; +import org.apache.cassandra.distributed.api.ConsistencyLevel; +import org.apache.cassandra.distributed.test.TestBaseImpl; + +import static org.apache.cassandra.distributed.api.Feature.GOSSIP; +import static org.apache.cassandra.distributed.api.Feature.NETWORK; +import static org.assertj.core.api.Assertions.assertThat; + +public class BM25DistributedTest extends TestBaseImpl +{ + private static final String CREATE_KEYSPACE = "CREATE KEYSPACE %%s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': %d}"; + private static final String CREATE_TABLE = "CREATE TABLE %s (k int PRIMARY KEY, v text)"; + private static final String CREATE_INDEX = "CREATE CUSTOM INDEX ON %s(v) USING 'StorageAttachedIndex' WITH OPTIONS = {'index_analyzer': '{\"tokenizer\" : {\"name\" : \"standard\"}, \"filters\" : [{\"name\" : \"porterstem\"}]}'}"; + + // To get consistent results from BM25 we need to know which docs are evaluated, the easiest way + // to do that is to put all the docs on every replica + private static final int NUM_NODES = 3; + private static final int RF = 3; + + private static Cluster cluster; + private static String table; + + private static final AtomicInteger seq = new AtomicInteger(); + + @BeforeClass + public static void setupCluster() throws Exception + { + cluster = Cluster.build(NUM_NODES) + .withTokenCount(1) + .withDataDirCount(1) + .withConfig(config -> config.with(GOSSIP).with(NETWORK)) + .start(); + + cluster.schemaChange(withKeyspace(String.format(CREATE_KEYSPACE, RF))); + } + + @AfterClass + public static void closeCluster() + { + if (cluster != null) + 
cluster.close(); + } + + @Before + public void before() + { + table = "table_" + seq.getAndIncrement(); + cluster.schemaChange(formatQuery(CREATE_TABLE)); + cluster.schemaChange(formatQuery(CREATE_INDEX)); + SAIUtil.waitForIndexQueryable(cluster, KEYSPACE); + } + + @Test + public void testTermFrequencyOrdering() + { + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + // Query memtable index + assertBM25Ordering(); + + // Flush and query on-disk index + cluster.forEach(n -> n.flush(KEYSPACE)); + assertBM25Ordering(); + } + + private void assertBM25Ordering() + { + Object[][] result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertThat(result).hasNumberOfRows(3); + + // Results should be ordered by term frequency (highest to lowest) + assertThat((Integer) result[0][0]).isEqualTo(3); // 3 occurrences + assertThat((Integer) result[1][0]).isEqualTo(2); // 2 occurrences + assertThat((Integer) result[2][0]).isEqualTo(1); // 1 occurrence + } + + private static Object[][] execute(String query) + { + return execute(query, ConsistencyLevel.QUORUM); + } + + private static Object[][] execute(String query, ConsistencyLevel consistencyLevel) + { + return cluster.coordinator(1).execute(formatQuery(query), consistencyLevel); + } + + private static String formatQuery(String query) + { + return String.format(query, KEYSPACE + '.' 
+ table); + } +} diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index f279026672c9..dafe0e4572d8 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -805,6 +805,11 @@ protected String currentIndex() return indexes.get(indexes.size() - 1); } + protected String getIndex(int i) + { + return indexes.get(i); + } + protected Collection currentTables() { if (tables == null || tables.isEmpty()) diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java index dac86c76e5d8..34f7d606ce55 100644 --- a/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java +++ b/test/unit/org/apache/cassandra/cql3/validation/operations/SelectOrderByTest.java @@ -659,7 +659,7 @@ public void testAllowSkippingEqualityAndSingleValueInRestrictedClusteringColumns assertInvalidMessage("Cannot combine clustering column ordering with non-clustering column ordering", "SELECT * FROM %s WHERE a=? ORDER BY b ASC, c ASC, d ASC", 0); - String errorMsg = "Order by currently only supports the ordering of columns following their declared order in the PRIMARY KEY"; + String errorMsg = "Ordering by clustered columns must follow the declared order in the PRIMARY KEY"; assertRows(execute("SELECT * FROM %s WHERE a=? AND b=? ORDER BY c", 0, 0), row(0, 0, 0, 0), diff --git a/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java new file mode 100644 index 000000000000..ef9d09424ef5 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/cql/BM25Test.java @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.cql; + +import java.util.Collection; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import org.apache.cassandra.index.sai.SAITester; +import org.apache.cassandra.index.sai.SAIUtil; +import org.apache.cassandra.index.sai.disk.format.Version; +import org.apache.cassandra.index.sai.plan.QueryController; + +import static org.apache.cassandra.index.sai.analyzer.AnalyzerEqOperatorSupport.EQ_AMBIGUOUS_ERROR; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; + +@RunWith(Parameterized.class) +public class BM25Test extends SAITester +{ + @Parameterized.Parameter + public Version version; + + @Parameterized.Parameters(name = "{0}") + public static Collection data() + { + return Stream.of(Version.DC, Version.EB).map(v -> new Object[]{ v}).collect(Collectors.toList()); + } + + @Before + public void setup() throws Throwable + { + SAIUtil.setLatestVersion(version); + } + + @Test + public void testTwoIndexes() + { + // create un-analyzed index + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 
'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + // BM25 should fail with only an equality index + assertInvalidMessage("BM25 ordering on column v requires an analyzed index", + "SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + + // create analyzed index + analyzeIndex(); + // BM25 query should work now + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + } + + @Test + public void testTwoIndexesAmbiguousPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + + // Create analyzed and un-analyzed indexes + analyzeIndex(); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'orange juice')"); + + // equality predicate is ambiguous (both analyzed and un-analyzed indexes could support it) so it should + // be rejected + beforeAndAfterFlush(() -> { + // Single predicate + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple'"); + + // AND + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' AND v : 'juice'"); + + // OR + assertInvalidMessage(String.format(EQ_AMBIGUOUS_ERROR, "v", getIndex(0), getIndex(1)), + "SELECT k FROM %s WHERE v = 'apple' OR v : 'juice'"); + }); + } + + @Test + public void testTwoIndexesWithEqualsUnsupported() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + createIndex("CREATE CUSTOM INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + // analyzed index with equals_behavior:unsupported option + createIndex("CREATE CUSTOM 
INDEX ON %s(v) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = { 'equals_behaviour_when_analyzed': 'unsupported', " + + "'index_analyzer':'{\"tokenizer\":{\"name\":\"standard\"},\"filters\":[{\"name\":\"porterstem\"}]}' }"); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple juice')"); + + beforeAndAfterFlush(() -> { + // combining two EQ predicates is not allowed + assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v = 'juice'"); + + // combining EQ and MATCH predicates is also not allowed (when we're not converting EQ to MATCH) + assertInvalid("SELECT k FROM %s WHERE v = 'apple' AND v : 'apple'"); + + // combining two MATCH predicates is fine + assertRows(execute("SELECT k FROM %s WHERE v : 'apple' AND v : 'juice'"), + row(2)); + + // = operator should use un-analyzed index since equals is unsupported in analyzed index + assertRows(execute("SELECT k FROM %s WHERE v = 'apple'"), + row(1)); + + // : operator should use analyzed index + assertRows(execute("SELECT k FROM %s WHERE v : 'apple'"), + row(1), row(2)); + }); + } + + @Test + public void testComplexQueriesWithMultipleIndexes() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v1 text, v2 text, v3 int)"); + + // Create mix of analyzed, unanalyzed, and non-text indexes + createIndex("CREATE CUSTOM INDEX ON %s(v1) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + createIndex("CREATE CUSTOM INDEX ON %s(v2) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'" + + "}"); + createIndex("CREATE CUSTOM INDEX ON %s(v3) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"); + + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (1, 'apple', 'orange juice', 5)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (2, 
'apple juice', 'apple', 10)"); + execute("INSERT INTO %s (k, v1, v2, v3) VALUES (3, 'banana', 'grape juice', 5)"); + + beforeAndAfterFlush(() -> { + // Complex query mixing different types of indexes and operators + assertRows(execute("SELECT k FROM %s WHERE v1 = 'apple' AND v2 : 'juice' AND v3 = 5"), + row(1)); + + // Mix of AND and OR conditions across different index types + assertRows(execute("SELECT k FROM %s WHERE v3 = 5 AND (v1 = 'apple' OR v2 : 'apple')"), + row(1)); + + // Multi-term analyzed query + assertRows(execute("SELECT k FROM %s WHERE v2 : 'orange juice'"), + row(1)); + + // Range query with text match + assertRows(execute("SELECT k FROM %s WHERE v3 >= 5 AND v2 : 'juice'"), + row(1), row(3)); + }); + } + + @Test + public void testMatchingAllowed() throws Throwable + { + // match operator should be allowed with BM25 on the same column + // (seems obvious but exercises a corner case in the internal RestrictionSet processing) + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s WHERE v : 'apple' ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, row(1)); + }); + } + + @Test + public void testUnknownQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'orange' LIMIT 1"); + assertEmpty(result); + }); + } + + @Test + public void testDuplicateQueryTerm() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple apple' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testEmptyQuery() throws Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + + beforeAndAfterFlush(() -> + { + 
assertInvalidMessage("BM25 query must contain at least one term (perhaps your analyzer is discarding tokens you didn't expect)", + "SELECT k FROM %s ORDER BY v BM25 OF '+' LIMIT 1"); + }); + } + + @Test + public void testTermFrequencyOrdering() throws Throwable + { + createSimpleTable(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, v) VALUES (1, 'apple')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testDocumentLength() throws Throwable + { + createSimpleTable(); + // Create documents with same term frequency but different lengths + execute("INSERT INTO %s (k, v) VALUES (1, 'test test')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'test test other words here to make it longer')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'test test extremely long document with many additional words to significantly increase the document length while maintaining the same term frequency for our target term')"); + + beforeAndAfterFlush(() -> + { + // Documents with same term frequency should be ordered by length (shorter first) + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 3"); + assertRows(result, + row(1), + row(2), + row(3)); + }); + } + + @Test + public void testMultiTermQueryScoring() throws Throwable + { + createSimpleTable(); + // Two terms, but "apple" appears in fewer documents + execute("INSERT INTO %s (k, v) VALUES (1, 'apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (2, 'apple apple banana')"); + execute("INSERT INTO %s (k, v) VALUES (3, 'apple banana banana')"); + execute("INSERT 
INTO %s (k, v) VALUES (4, 'apple apple banana banana')"); + execute("INSERT INTO %s (k, v) VALUES (5, 'banana banana')"); + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'apple banana' LIMIT 4"); + assertRows(result, + row(2), // Highest frequency of most important term + row(4), // More mentions of both terms + row(1), // One of each term + row(3)); // Low frequency of most important term + }); + } + + @Test + public void testIrrelevantRowsScoring() throws Throwable + { + createSimpleTable(); + // Insert pizza reviews with varying relevance to "crispy crust" + execute("INSERT INTO %s (k, v) VALUES (1, 'The pizza had a crispy crust and was delicious')"); // Basic mention + execute("INSERT INTO %s (k, v) VALUES (2, 'Very crispy crispy crust, perfectly cooked')"); // Emphasized crispy + execute("INSERT INTO %s (k, v) VALUES (3, 'The crust crust crust was okay, nothing special')"); // Only crust mentions + execute("INSERT INTO %s (k, v) VALUES (4, 'Super crispy crispy crust crust, best pizza ever!')"); // Most mentions of both + execute("INSERT INTO %s (k, v) VALUES (5, 'The toppings were good but the pizza was soggy')"); // Irrelevant review + + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'crispy crust' LIMIT 5"); + assertRows(result, + row(4), // Highest frequency of both terms + row(2), // High frequency of 'crispy', one 'crust' + row(1)); // One mention of each term + // Rows 3 and 5 do not contain all terms + }); + } + + private void createSimpleTable() + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, v text)"); + analyzeIndex(); + } + + private String analyzeIndex() + { + return createIndex("CREATE CUSTOM INDEX ON %s(v) " + + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' " + + "WITH OPTIONS = {" + + "'index_analyzer': '{" + + "\"tokenizer\" : {\"name\" : \"standard\"}, " + + "\"filters\" : [{\"name\" : \"porterstem\"}]" + + "}'}" + ); + } + + @Test +
public void testWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k int PRIMARY KEY, p int, v text)"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k, p, v) VALUES (1, 5, 'apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (2, 5, 'apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k, p, v) VALUES (5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartition() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPkPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 1, 'apple')"); + 
execute("INSERT INTO %s (k1, k2, v) VALUES (0, 2, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (0, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (1, 3, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, v) VALUES (2, 3, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE k1 = 0 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWidePartitionWithPredicate() throws Throwable + { + createTable("CREATE TABLE %s (k1 int, k2 int, p int, v text, PRIMARY KEY (k1, k2))"); + analyzeIndex(); + execute("CREATE CUSTOM INDEX ON %s(p) USING 'StorageAttachedIndex'"); + + // Insert documents with varying frequencies of the term "apple" + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 1, 5, 'apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 2, 5, 'apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 3, 5, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 4, 6, 'apple apple apple')"); + execute("INSERT INTO %s (k1, k2, p, v) VALUES (0, 5, 7, 'apple apple apple')"); + + beforeAndAfterFlush(() -> + { + // Results should be ordered by term frequency (highest to lowest) + var result = execute("SELECT k2 FROM %s WHERE p = 5 ORDER BY v BM25 OF 'apple' LIMIT 3"); + assertRows(result, + row(3), // 3 occurrences + row(2), // 2 occurrences + row(1)); // 1 occurrence + }); + } + + @Test + public void testWithPredicateSearchThenOrder() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 0; + testWithPredicate(); + } + + @Test + public void testWidePartitionWithPredicateOrderThenSearch() throws Throwable + { + QueryController.QUERY_OPT_LEVEL = 1; + testWidePartitionWithPredicate(); + } + + @Test + public void testQueryWithNulls() throws 
Throwable + { + createSimpleTable(); + + execute("INSERT INTO %s (k, v) VALUES (0, null)"); + execute("INSERT INTO %s (k, v) VALUES (1, 'test document')"); + beforeAndAfterFlush(() -> + { + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertRows(result, row(1)); + }); + } + + @Test + public void testQueryEmptyTable() + { + createSimpleTable(); + var result = execute("SELECT k FROM %s ORDER BY v BM25 OF 'test' LIMIT 1"); + assertThat(result).hasSize(0); + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java index 5364717a8906..8d952dc1f4f0 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/MultipleColumnIndexTest.java @@ -40,12 +40,30 @@ public void canCreateMultipleMapIndexesOnSameColumn() throws Throwable } @Test - public void cannotHaveMultipleLiteralIndexesWithDifferentOptions() throws Throwable + public void canHaveAnalyzedAndUnanalyzedIndexesOnSameColumn() throws Throwable { - createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))"); + createTable("CREATE TABLE %s (pk int, value text, PRIMARY KEY(pk))"); createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }"); - assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }")) - .isInstanceOf(InvalidRequestException.class); + createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false, 'equals_behaviour_when_analyzed': 'unsupported' }"); + + execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 1, "a"); + execute("INSERT INTO %s (pk, value) VALUES (?, ?)", 2, "A"); + beforeAndAfterFlush(() -> { + assertRows(execute("SELECT pk FROM %s WHERE value = 'a'"), + 
row(1)); + assertRows(execute("SELECT pk FROM %s WHERE value : 'a'"), + row(1), + row(2)); + }); + } + + @Test + public void cannotHaveMultipleAnalyzingIndexesOnSameColumn() throws Throwable + { + createTable("CREATE TABLE %s (pk int, ck int, value text, PRIMARY KEY(pk, ck))"); + createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : false }"); + assertThatThrownBy(() -> createIndex("CREATE CUSTOM INDEX ON %s(value) USING 'StorageAttachedIndex' WITH OPTIONS = { 'normalize' : true }")) + .isInstanceOf(InvalidRequestException.class); } @Test diff --git a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java index 12b597e85ca8..8d4ac97b2a2a 100644 --- a/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java +++ b/test/unit/org/apache/cassandra/index/sai/cql/NativeIndexDDLTest.java @@ -581,7 +581,7 @@ public void shouldFailCreationMultipleIndexesOnSimpleColumn() // different name, different option, same target. 
assertThatThrownBy(() -> executeNet("CREATE CUSTOM INDEX ON %s(v1) USING 'StorageAttachedIndex' WITH OPTIONS = { 'case_sensitive' : true }")) .isInstanceOf(InvalidQueryException.class) - .hasMessageContaining("Cannot create more than one storage-attached index on the same column: v1" ); + .hasMessageContaining("Cannot create duplicate storage-attached index on column: v1" ); ResultSet rows = executeNet("SELECT id FROM %s WHERE v1 = '1'"); assertEquals(1, rows.all().size()); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java index 47565eb28739..899c03a95ccc 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcherTest.java @@ -106,7 +106,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception for (int t = 0; t < numTerms; ++t) { try (KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT)) + .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -121,7 +121,7 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception } try (KeyRangeIterator results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false, LIMIT)) + .add(Operator.EQ, termsEnum.get(t).originalTermBytes), null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -143,12 +143,12 @@ private void doTestEqQueriesAgainstStringIndex(Version version) throws Exception // try searching for terms that weren't indexed final String tooLongTerm = randomSimpleString(10, 12); KeyRangeIterator results = searcher.search(new Expression(indexContext) - 
.add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false, LIMIT); + .add(Operator.EQ, UTF8Type.instance.decompose(tooLongTerm)), null, new QueryContext(), false); assertFalse(results.hasNext()); final String tooShortTerm = randomSimpleString(1, 2); results = searcher.search(new Expression(indexContext) - .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false, LIMIT); + .add(Operator.EQ, UTF8Type.instance.decompose(tooShortTerm)), null, new QueryContext(), false); assertFalse(results.hasNext()); } } @@ -162,7 +162,7 @@ public void testUnsupportedOperator() throws Exception try (IndexSearcher searcher = buildIndexAndOpenSearcher(numTerms, numPostings, termsEnum)) { searcher.search(new Expression(indexContext) - .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false, LIMIT); + .add(Operator.NEQ, UTF8Type.instance.decompose("a")), null, new QueryContext(), false); fail("Expect IllegalArgumentException thrown, but didn't"); } diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java index 8f34b9a808aa..65242e319a37 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/KDTreeIndexSearcherTest.java @@ -151,7 +151,7 @@ public void testUnsupportedOperator() throws Exception {{ operation = Op.NOT_EQ; lower = upper = new Bound(ShortType.instance.decompose((short) 0), Int32Type.instance, true); - }}, null, new QueryContext(), false, LIMIT); + }}, null, new QueryContext(), false); fail("Expect IllegalArgumentException thrown, but didn't"); } @@ -169,7 +169,7 @@ private void testEqQueries(final IndexSearcher indexSearcher, {{ operation = Op.EQ; lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, true); - }}, null, 
new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -180,7 +180,7 @@ private void testEqQueries(final IndexSearcher indexSearcher, {{ operation = Op.EQ; lower = upper = new Bound(rawType.decompose(rawValueProducer.apply(EQ_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); indexSearcher.close(); @@ -206,7 +206,7 @@ private void testRangeQueries(final IndexSearcher indexSearch lower = new Bound(rawType.decompose(rawValueProducer.apply((short)2)), encodedType, false); upper = new Bound(rawType.decompose(rawValueProducer.apply((short)7)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertTrue(results.hasNext()); @@ -218,7 +218,7 @@ private void testRangeQueries(final IndexSearcher indexSearch {{ operation = Op.RANGE; lower = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_UPPER_BOUND_EXCLUSIVE)), encodedType, true); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); } @@ -227,7 +227,7 @@ private void testRangeQueries(final IndexSearcher indexSearch {{ operation = Op.RANGE; upper = new Bound(rawType.decompose(rawValueProducer.apply(RANGE_TEST_LOWER_BOUND_INCLUSIVE)), encodedType, false); - }}, null, new QueryContext(), false, LIMIT)) + }}, null, new QueryContext(), false)) { assertFalse(results.hasNext()); indexSearcher.close(); diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java new file mode 100644 index 000000000000..40f6f16ef970 --- /dev/null +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/IntersectingPostingListTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.index.sai.disk.v1.postings; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; + +import org.junit.Test; + +import org.apache.cassandra.index.sai.disk.PostingList; +import org.apache.cassandra.index.sai.postings.IntArrayPostingList; +import org.apache.cassandra.index.sai.utils.SaiRandomizedTest; + +public class IntersectingPostingListTest extends SaiRandomizedTest +{ + @Test + public void shouldIntersectOverlappingPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 4, 6, 8 }), + new IntArrayPostingList(new int[]{ 2, 4, 6, 9 }), + new IntArrayPostingList(new int[]{ 4, 6, 7 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 4, 6 }), intersected); + } + + @Test + public void shouldIntersectDisjointPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5 }), + new IntArrayPostingList(new int[]{ 2, 4, 6 })); + + final PostingList 
intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{}), intersected); + } + + @Test + public void shouldIntersectSinglePostingList() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 4, 6 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 4, 6 }), intersected); + } + + @Test + public void shouldIntersectIdenticalPostingLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 2, 3 }), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertPostingListEquals(new IntArrayPostingList(new int[]{ 1, 2, 3 }), intersected); + } + + @Test + public void shouldAdvanceAllIntersectedLists() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 2, 3, 5, 7, 8 }), + new IntArrayPostingList(new int[]{ 3, 5, 7, 10 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + final PostingList expected = new IntArrayPostingList(new int[]{ 3, 5, 7 }); + + assertEquals(expected.advance(5), + intersected.advance(5)); + + assertPostingListEquals(expected, intersected); + } + + @Test + public void shouldHandleEmptyList() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{}), + new IntArrayPostingList(new int[]{ 1, 2, 3 })); + + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertEquals(PostingList.END_OF_STREAM, intersected.advance(1)); + } + + @Test + public void shouldInterleaveNextAndAdvance() throws IOException + { + var lists = listOfLists(new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 }), + new IntArrayPostingList(new int[]{ 1, 3, 5, 7, 9 })); 
+ + final PostingList intersected = IntersectingPostingList.intersect(lists); + + assertEquals(1, intersected.nextPosting()); + assertEquals(5, intersected.advance(5)); + assertEquals(7, intersected.nextPosting()); + assertEquals(9, intersected.advance(9)); + } + + @Test + public void shouldInterleaveNextAndAdvanceOnRandom() throws IOException + { + for (int i = 0; i < 1000; ++i) + { + testAdvancingOnRandom(); + } + } + + private void testAdvancingOnRandom() throws IOException + { + final int postingsCount = nextInt(1, 50_000); + final int postingListCount = nextInt(2, 10); + + // Generate base postings that will be present in all lists + final AtomicInteger rowId = new AtomicInteger(); + final int[] commonPostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10))) + .limit(postingsCount / 4) // Fewer common elements + .toArray(); + + var splitPostingLists = new ArrayList(); + for (int i = 0; i < postingListCount; i++) + { + // Combine common postings with some unique ones for each list + final int[] uniquePostings = IntStream.generate(() -> rowId.addAndGet(nextInt(1, 10))) + .limit(postingsCount) + .toArray(); + int[] combined = IntStream.concat(IntStream.of(commonPostings), + IntStream.of(uniquePostings)) + .distinct() + .sorted() + .toArray(); + splitPostingLists.add(new IntArrayPostingList(combined)); + } + + final PostingList intersected = IntersectingPostingList.intersect(splitPostingLists); + final PostingList expected = new IntArrayPostingList(commonPostings); + + final List actions = new ArrayList<>(); + for (int idx = 0; idx < commonPostings.length; idx++) + { + if (nextInt(0, 8) == 0) + { + actions.add((postingList) -> { + try + { + return postingList.nextPosting(); + } + catch (IOException e) + { + fail(e.getMessage()); + throw new RuntimeException(e); + } + }); + } + else + { + final int skips = nextInt(0, 5); + idx = Math.min(idx + skips, commonPostings.length - 1); + final int rowID = commonPostings[idx]; + actions.add((postingList) -> { + 
try + { + return postingList.advance(rowID); + } + catch (IOException e) + { + fail(e.getMessage()); + throw new RuntimeException(e); + } + }); + } + } + + for (PostingListAdvance action : actions) + { + assertEquals(action.advance(expected), action.advance(intersected)); + } + } + + private ArrayList listOfLists(PostingList... postingLists) + { + var L = new ArrayList(); + Collections.addAll(L, postingLists); + return L; + } + + private interface PostingListAdvance + { + long advance(PostingList list) throws IOException; + } +} diff --git a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java similarity index 91% rename from test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java rename to test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java index 3135db3c7748..10b2b6a33f47 100644 --- a/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/VectorPostingListTest.java +++ b/test/unit/org/apache/cassandra/index/sai/disk/v1/postings/ReorderingPostingListTest.java @@ -30,14 +30,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class VectorPostingListTest +public class ReorderingPostingListTest { @Test public void ensureEmptySourceBehavesCorrectly() throws Throwable { var source = new TestIterator(CloseableIterator.emptyIterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { // Even an empty source should be closed assertTrue(source.isClosed); @@ -55,7 +55,7 @@ public void ensureIteratorIsConsumedClosedAndReordered() throws Throwable new RowIdWithScore(4, 4), }).iterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { 
// The posting list is eagerly consumed, so it should be closed before // we close postingList @@ -80,7 +80,7 @@ public void ensureAdvanceWorksCorrectly() throws Throwable new RowIdWithScore(2, 2), }).iterator()); - try (var postingList = new VectorPostingList(source)) + try (var postingList = new ReorderingPostingList(source, RowIdWithScore::getSegmentRowId)) { assertEquals(3, postingList.advance(3)); assertEquals(PostingList.END_OF_STREAM, postingList.advance(4)); diff --git a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java index 75e8c83f30be..2197c00fe231 100644 --- a/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java +++ b/test/unit/org/apache/cassandra/index/sai/memory/VectorMemtableIndexTest.java @@ -148,7 +148,7 @@ private void validate(List keys) { IntStream.range(0, 1_000).parallel().forEach(i -> { - var orderer = generateRandomOrderer(); + var orderer = randomVectorOrderer(); AbstractBounds keyRange = generateRandomBounds(keys); // compute keys in range of the bounds Set keysInRange = keys.stream().filter(keyRange::contains) @@ -197,7 +197,7 @@ public void indexIteratorTest() // VSTODO } - private Orderer generateRandomOrderer() + private Orderer randomVectorOrderer() { return new Orderer(indexContext, Operator.ANN, randomVectorSerialized()); }