Skip to content

Commit

Permalink
Refactor group IDs to set at the parse level
Browse files Browse the repository at this point in the history
This ensures group IDs are static for each parsed node.  We also ensure
that we check for at least GroupID.Size() matching items from trees when
matching on incoming events.
  • Loading branch information
tonyhb committed Jan 5, 2024
1 parent 921f6fc commit db40612
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 42 deletions.
71 changes: 54 additions & 17 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,21 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
// are added (eg. >= operators on strings), ensure that we find the correct number of matches
// for each group ID and then skip evaluating expressions if the number of matches is < the group
// ID's size.
seen := map[groupID]struct{}{}

for _, match := range matches {
if _, ok := seen[match.GroupID]; ok {
continue
}

atomic.AddInt32(&matched, 1)
// NOTE: We don't need to add lifted expression variables,
// because match.Parsed.Evaluable() returns the original expression
// string.
ok, evalerr := a.eval(ctx, match.Parsed.Evaluable, data)

seen[match.GroupID] = struct{}{}

if evalerr != nil {
err = errors.Join(err, evalerr)
continue
Expand All @@ -161,6 +170,12 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([
a.lock.RLock()
defer a.lock.RUnlock()

// Store the number of times each GroupID has found a match. We need at least
// as many matches as stored in the group ID to consider the match.
counts := map[groupID]int{}
// Store all expression parts per group ID for returning.
found := map[groupID][]ExpressionPart{}

// Iterate through all known variables/idents in the aggregate tree to see if
// the data has those keys set. If so, we can immediately evaluate the data with
// the tree.
Expand All @@ -179,16 +194,32 @@ func (a *aggregator) AggregateMatch(ctx context.Context, data map[string]any) ([

switch cast := res[0].(type) {
case string:
found, ok := tree.Search(ctx, cast)
all, ok := tree.Search(ctx, cast)
if !ok {
continue
}
result = append(result, found.Evals...)

for _, eval := range all.Evals {
counts[eval.GroupID] += 1
if _, ok := found[eval.GroupID]; !ok {
found[eval.GroupID] = []ExpressionPart{}
}
found[eval.GroupID] = append(found[eval.GroupID], eval)
}
default:
continue
}
}

for k, count := range counts {
if int(k.Size()) > count {
// The GroupID required more comparisons to equate to true than
// we had, so this could never evaluate to true. Skip this.
continue
}
result = append(result, found[k]...)
}

return result, nil
}

Expand Down Expand Up @@ -238,16 +269,26 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
return false, nil
}

// Merge all of the nodes together and check whether each node is aggregateable.
all := append(node.Ands, node)
for _, n := range all {
if !n.HasPredicate() || len(n.Ors) > 0 {
// Don't handle sub-branching for now.
return false, nil
if len(node.Ands) > 0 {
for _, n := range node.Ands {
if !n.HasPredicate() || len(n.Ors) > 0 {
// Don't handle sub-branching for now.
return false, nil
}
if !isAggregateable(n) {
return false, nil
}
}
if !isAggregateable(n) {
}

all := node.Ands

if node.Predicate != nil {
if !isAggregateable(node) {
return false, nil
}
// Merge all of the nodes together and check whether each node is aggregateable.
all = append(node.Ands, node)
}

// Create a new group ID which tracks the number of expressions that must match
Expand All @@ -258,9 +299,8 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
// When checking an incoming event, we match the event against each node's
// ident/variable. Using the group ID, we can see if we've matched N necessary
// items from the same identifier. If so, the evaluation is true.
groupID := newGroupID(uint16(len(all)))
for _, n := range all {
err := a.addNode(ctx, n, groupID, parsed)
err := a.addNode(ctx, n, parsed)
if err == errTreeUnimplemented {
return false, nil
}
Expand All @@ -272,7 +312,7 @@ func (a *aggregator) addGroup(ctx context.Context, node *Node, parsed *ParsedExp
return true, nil
}

func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *ParsedExpression) error {
func (a *aggregator) addNode(ctx context.Context, n *Node, parsed *ParsedExpression) error {
// Don't allow anything to update in parallel. This ensures that Add() can be called
// concurrently.
a.lock.Lock()
Expand All @@ -286,7 +326,7 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *
tree = newArtTree()
}
err := tree.Add(ctx, ExpressionPart{
GroupID: gid,
GroupID: n.GroupID,
Predicate: *n.Predicate,
Parsed: parsed,
})
Expand All @@ -302,14 +342,11 @@ func (a *aggregator) addNode(ctx context.Context, n *Node, gid groupID, parsed *
// Remove is a placeholder for removing a previously added expression from
// the aggregator. Removal is not implemented yet.
//
// The expression is still run through the tree parser so that an expression
// which fails to parse surfaces the parser's error unchanged; for every
// expression that parses, Remove returns a "not implemented" error.
func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error {
	// Only the parse error matters here: nothing is done with the parsed
	// expression until removal is actually implemented.
	if _, err := a.parser.Parse(ctx, eval); err != nil {
		return err
	}

	return fmt.Errorf("not implemented")
}

Expand Down
69 changes: 47 additions & 22 deletions expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,7 @@ func TestEvaluate(t *testing.T) {

require.NoError(t, err)
require.EqualValues(t, 0, len(evals))
require.EqualValues(t, 1, matched) // We still ran one expression
})

t.Run("It handles matching on arrays of data", func(t *testing.T) {
pre := time.Now()
evals, matched, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"data": map[string]any{
"ids": []string{"a", "b", "c"},
},
},
})
total := time.Since(pre)
fmt.Printf("Matched in %v ns\n", total.Nanoseconds())
fmt.Printf("Matched in %v ms\n", total.Milliseconds())

require.NoError(t, err)
require.EqualValues(t, 0, len(evals))
require.EqualValues(t, 1, matched) // We still ran one expression
require.EqualValues(t, 0, matched) // No expressions were evaluated
})
}

Expand All @@ -170,7 +152,7 @@ func TestEvaluate_Concurrently(t *testing.T) {
require.NoError(t, err)

go func() {
for i := 0; i < 1_000; i++ {
for i := 0; i < 100_000; i++ {
//nolint:all
go func() {
byt := make([]byte, 8)
Expand Down Expand Up @@ -213,7 +195,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
require.NoError(t, err)
e := NewAggregateEvaluator(parser, testBoolEvaluator)

expected := tex(`event.data.ids[2] == "id-b"`)
expected := tex(`event.data.ids[1] == "id-b" && event.data.ids[2] == "id-c"`)
_, err = e.Add(ctx, expected)
require.NoError(t, err)

Expand Down Expand Up @@ -258,7 +240,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
evals, matched, err := e.Evaluate(ctx, map[string]any{
"event": map[string]any{
"data": map[string]any{
"ids": []string{"a", "yes", "id-b"},
"ids": []string{"id-a", "id-b", "id-c"},
},
},
})
Expand All @@ -272,6 +254,49 @@ func TestEvaluate_ArrayIndexes(t *testing.T) {
})
}

// TestEvaluate_Compound adds a single expression made of three &&-joined
// equality predicates and checks how Evaluate treats events that satisfy all
// of them versus only some of them.
func TestEvaluate_Compound(t *testing.T) {
	ctx := context.Background()
	parser, err := NewTreeParser(NewCachingParser(newEnv(), nil))
	require.NoError(t, err)
	e := NewAggregateEvaluator(parser, testBoolEvaluator)

	expected := tex(`event.data.a == "ok" && event.data.b == "yes" && event.data.c == "please"`)
	ok, err := e.Add(ctx, expected)
	require.True(t, ok)
	require.NoError(t, err)

	// Each case only varies the value of event.data.c; a and b always match.
	cases := []struct {
		name        string
		c           string
		wantMatched int
		wantEvals   []Evaluable
	}{
		{
			name:        "It matches items",
			c:           "please",
			wantMatched: 1, // We only perform one eval
			wantEvals:   []Evaluable{expected},
		},
		{
			name:        "It skips if less than the group length is found",
			c:           "no - no match",
			wantMatched: 0,
			wantEvals:   []Evaluable{},
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			input := map[string]any{
				"event": map[string]any{
					"data": map[string]any{
						"a": "ok",
						"b": "yes",
						"c": tc.c,
					},
				},
			}

			evals, matched, err := e.Evaluate(ctx, input)
			require.NoError(t, err)
			require.EqualValues(t, tc.wantMatched, matched)
			require.EqualValues(t, tc.wantEvals, evals)
		})
	}
}

func TestAggregateMatch(t *testing.T) {
ctx := context.Background()
parser, err := newParser()
Expand Down
9 changes: 8 additions & 1 deletion groupid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,27 @@ package expr
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
)

// groupID represents a group ID. The first 2 bytes are the uint16 size of the
// expression group (the number of predicates within the expression), stored in
// the host's native byte order. The last 6 bytes are a random ID for the
// predicate group.
type groupID [8]byte

// rander fills the given slice with random bytes. It defaults to
// crypto/rand.Read; because it is a package-level variable it can be replaced
// with a deterministic source.
var rander = rand.Read

// String returns the group ID as a 16-character lowercase hex string.
func (g groupID) String() string {
	return hex.EncodeToString(g[:])
}

// Size returns the group size stored in the first 2 bytes of the ID.
func (g groupID) Size() uint16 {
	return binary.NativeEndian.Uint16(g[:2])
}

// newGroupID returns a group ID whose size prefix is the given size and whose
// remaining 6 bytes are filled by rander.
func newGroupID(size uint16) groupID {
	// Build the ID directly in the array; no intermediate slice is needed.
	var id groupID
	binary.NativeEndian.PutUint16(id[:2], size)
	// The random suffix is filled exactly once, and always through rander so
	// that a replaced source is honoured. The error is deliberately ignored:
	// newGroupID has no error return, and on failure the suffix simply keeps
	// whatever bytes rander wrote.
	_, _ = rander(id[2:])
	return id
}
27 changes: 27 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ func (p ParsedExpression) RootGroups() []*Node {
// This requires A *and* either B or C, and so we require all ANDs plus at least one node
// from OR to evaluate to true
type Node struct {
GroupID groupID

// Ands contains predicates at this level of the expression that are joined together
// with an && operator. All nodes in this set must evaluate to true in order for this
// node in the expression to be truthy.
Expand Down Expand Up @@ -416,6 +418,31 @@ func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) {
}

parent.Ands = result

// Add a group ID to the parent.
total := len(parent.Ands)
if parent.Predicate != nil {
total += 1
}
if len(parent.Ors) >= 1 {
total += 1
}

parent.GroupID = newGroupID(uint16(total))
// For each sub-group, add the same group IDs to children if there's no nesting.
for n, item := range parent.Ands {
if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil {
item.GroupID = parent.GroupID
parent.Ands[n] = item
}
}
for n, item := range parent.Ors {
if len(item.Ands) == 0 && len(item.Ors) == 0 && item.Predicate != nil {
item.GroupID = parent.GroupID
parent.Ors[n] = item
}
}

return result, nil
}

Expand Down
Loading

0 comments on commit db40612

Please sign in to comment.