diff --git a/expr.go b/expr.go index 3608efb..eb0c951 100644 --- a/expr.go +++ b/expr.go @@ -9,6 +9,8 @@ import ( "github.com/google/cel-go/common/operators" "github.com/google/uuid" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" ) var ( @@ -70,11 +72,17 @@ func NewAggregateEvaluator( parser TreeParser, eval ExpressionEvaluator, evalLoader EvaluableLoader, + concurrency int64, ) AggregateEvaluator { + if concurrency <= 0 { + concurrency = 1 + } + return &aggregator{ eval: eval, parser: parser, loader: evalLoader, + sem: semaphore.NewWeighted(concurrency), engines: map[EngineType]MatchingEngine{ EngineTypeStringHash: newStringEqualityMatcher(), EngineTypeNullMatch: newNullMatcher(), @@ -92,6 +100,8 @@ type aggregator struct { // engines records all engines engines map[EngineType]MatchingEngine + sem *semaphore.Weighted + // lock prevents concurrent updates of data lock *sync.RWMutex // len stores the current len of aggregable expressions. @@ -123,6 +133,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu err error matched = int32(0) result = []Evaluable{} + s sync.Mutex ) // TODO: Concurrently match constant expressions using a semaphore for capacity. @@ -131,25 +142,49 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu if err != nil { return nil, 0, err } - for _, expr := range constantEvals { - atomic.AddInt32(&matched, 1) - if expr.GetExpression() == "" { - result = append(result, expr) - continue + eg := errgroup.Group{} + for _, item := range constantEvals { + if err := a.sem.Acquire(ctx, 1); err != nil { + return result, matched, err } - // NOTE: We don't need to add lifted expression variables, - // because match.Parsed.Evaluable() returns the original expression - // string. - ok, evalerr := a.eval(ctx, expr, data) - if evalerr != nil { - err = errors.Join(err, evalerr) - continue - } - if ok { - result = append(result, expr) - } + expr := item + eg.Go(func() error { + defer a.sem.Release(1) + defer func() { + if r := recover(); r != nil { + err = errors.Join(err, fmt.Errorf("recovered from panic in evaluate: %v", r)) + } + }() + + atomic.AddInt32(&matched, 1) + + if expr.GetExpression() == "" { + s.Lock() + result = append(result, expr) + s.Unlock() + return nil + } + + // NOTE: We don't need to add lifted expression variables, + // because match.Parsed.Evaluable() returns the original expression + // string. + ok, evalerr := a.eval(ctx, expr, data) + if evalerr != nil { + return evalerr + } + if ok { + s.Lock() + result = append(result, expr) + s.Unlock() + } + return nil + }) + } + + if werr := eg.Wait(); werr != nil { + err = errors.Join(err, werr) } matches, merr := a.AggregateMatch(ctx, data) @@ -173,26 +208,46 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu // ID's length. seen := map[uuid.UUID]struct{}{} + eg = errgroup.Group{} for _, match := range evaluables { - if _, ok := seen[match.GetID()]; ok { - continue + if err := a.sem.Acquire(ctx, 1); err != nil { + return result, matched, err } - atomic.AddInt32(&matched, 1) - // NOTE: We don't need to add lifted expression variables, - // because match.Parsed.Evaluable() returns the original expression - // string. - ok, evalerr := a.eval(ctx, match, data) + expr := match + eg.Go(func() error { + defer a.sem.Release(1) + defer func() { + if r := recover(); r != nil { + err = errors.Join(err, fmt.Errorf("recovered from panic in evaluate: %v", r)) + } + }() + + if _, ok := seen[expr.GetID()]; ok { + return nil + } - seen[match.GetID()] = struct{}{} + atomic.AddInt32(&matched, 1) + // NOTE: We don't need to add lifted expression variables, + // because match.Parsed.Evaluable() returns the original expression + // string. + ok, evalerr := a.eval(ctx, expr, data) - if evalerr != nil { - err = errors.Join(err, evalerr) - continue - } - if ok { - result = append(result, match) - } + seen[expr.GetID()] = struct{}{} + if evalerr != nil { + return evalerr + } + if ok { + s.Lock() + result = append(result, expr) + s.Unlock() + } + return nil + }) + } + + if werr := eg.Wait(); werr != nil { + err = errors.Join(err, werr) } return result, matched, err diff --git a/expr_test.go b/expr_test.go index 62f5c0e..1e16e73 100644 --- a/expr_test.go +++ b/expr_test.go @@ -77,7 +77,7 @@ func evaluate(b *testing.B, i int, parser TreeParser) error { loader := newEvalLoader() loader.AddEval(expected) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) _, _ = e.Add(ctx, expected) addOtherExpressions(i, e, loader) @@ -108,7 +108,7 @@ func TestAdd(t *testing.T) { expr := tex(`event.data == {"a":1}`) loader.AddEval(expr) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) _, err := e.Add(ctx, expr) require.NoError(t, err) @@ -124,7 +124,7 @@ func TestEvaluate_Strings(t *testing.T) { loader := newEvalLoader() loader.AddEval(expected) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) _, err := e.Add(ctx, expected) require.NoError(t, err) @@ -180,63 +180,127 @@ func TestEvaluate_Numbers(t *testing.T) { ctx := context.Background() parser := NewTreeParser(NewCachingCompiler(newEnv(), nil)) - // This is the expected epression - expected := tex(`326909.0 == event.data.account_id && (event.data.ts == null || event.data.ts > 1714000000000)`) - // expected := tex(`event.data.id == 25`) - loader := newEvalLoader() - loader.AddEval(expected) + t.Run("With annoying floats", func(t *testing.T) { + // This is the expected epression + expected := tex(`4.797009e+06 == event.data.id && (event.data.ts == null || event.data.ts > 1715211850340)`) + // expected := tex(`event.data.id == 25`) + loader := newEvalLoader() + loader.AddEval(expected) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) - _, err := e.Add(ctx, expected) - require.NoError(t, err) + _, err := e.Add(ctx, expected) + require.NoError(t, err) - n := 1 + n := 1 - addOtherExpressions(n, e, loader) + addOtherExpressions(n, e, loader) - require.EqualValues(t, n+1, e.Len()) + require.EqualValues(t, n+1, e.Len()) - t.Run("It matches items", func(t *testing.T) { - pre := time.Now() - evals, matched, err := e.Evaluate(ctx, map[string]any{ - "event": map[string]any{ - "data": map[string]any{ - "account_id": 326909, - "ts": 1714000000001, + t.Run("It matches items", func(t *testing.T) { + pre := time.Now() + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "id": 4797009, + "ts": 2015211850340, + }, }, - }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms (%d)\n", total.Milliseconds(), matched) + + require.NoError(t, err) + require.EqualValues(t, []Evaluable{expected}, evals) + + // Assert that we only evaluate one expression. + require.Equal(t, matched, int32(1)) }) - total := time.Since(pre) - fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) - fmt.Printf("Matched in %v ms (%d)\n", total.Milliseconds(), matched) - require.NoError(t, err) - require.EqualValues(t, []Evaluable{expected}, evals) + t.Run("It handles non-matching data", func(t *testing.T) { + pre := time.Now() + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": "yes", + "ts": "???", + "match": "no", + }, + }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms\n", total.Milliseconds()) - // Assert that we only evaluate one expression. - require.Equal(t, matched, int32(1)) + require.NoError(t, err) + require.EqualValues(t, 0, len(evals)) + // require.EqualValues(t, 0, matched) // We still ran one expression + _ = matched + }) }) - t.Run("It handles non-matching data", func(t *testing.T) { - pre := time.Now() - evals, matched, err := e.Evaluate(ctx, map[string]any{ - "event": map[string]any{ - "data": map[string]any{ - "account_id": "yes", - "ts": "???", - "match": "no", + t.Run("With floats", func(t *testing.T) { + + // This is the expected epression + expected := tex(`326909.0 == event.data.account_id && (event.data.ts == null || event.data.ts > 1714000000000)`) + // expected := tex(`event.data.id == 25`) + loader := newEvalLoader() + loader.AddEval(expected) + + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) + + _, err := e.Add(ctx, expected) + require.NoError(t, err) + + n := 1 + + addOtherExpressions(n, e, loader) + + require.EqualValues(t, n+1, e.Len()) + + t.Run("It matches items", func(t *testing.T) { + pre := time.Now() + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": 326909, + "ts": 1714000000001, + }, }, - }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms (%d)\n", total.Milliseconds(), matched) + + require.NoError(t, err) + require.EqualValues(t, []Evaluable{expected}, evals) + + // Assert that we only evaluate one expression. + require.Equal(t, matched, int32(1)) }) - total := time.Since(pre) - fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) - fmt.Printf("Matched in %v ms\n", total.Milliseconds()) - require.NoError(t, err) - require.EqualValues(t, 0, len(evals)) - // require.EqualValues(t, 0, matched) // We still ran one expression - _ = matched + t.Run("It handles non-matching data", func(t *testing.T) { + pre := time.Now() + evals, matched, err := e.Evaluate(ctx, map[string]any{ + "event": map[string]any{ + "data": map[string]any{ + "account_id": "yes", + "ts": "???", + "match": "no", + }, + }, + }) + total := time.Since(pre) + fmt.Printf("Matched in %v ns\n", total.Nanoseconds()) + fmt.Printf("Matched in %v ms\n", total.Milliseconds()) + + require.NoError(t, err) + require.EqualValues(t, 0, len(evals)) + // require.EqualValues(t, 0, matched) // We still ran one expression + _ = matched + }) }) } @@ -248,7 +312,7 @@ func TestEvaluate_Concurrently(t *testing.T) { loader := newEvalLoader() loader.AddEval(expected) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) _, err := e.Add(ctx, expected) require.NoError(t, err) @@ -287,7 +351,7 @@ func TestEvaluate_ArrayIndexes(t *testing.T) { loader := newEvalLoader() loader.AddEval(expected) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) _, err := e.Add(ctx, expected) require.NoError(t, err) @@ -337,7 +401,7 @@ func TestEvaluate_Compound(t *testing.T) { loader := newEvalLoader() loader.AddEval(expected) - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) ok, err := e.Add(ctx, expected) require.True(t, ok) @@ -384,7 +448,7 @@ func TestAggregateMatch(t *testing.T) { loader := newEvalLoader() - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) // Add three expressions matching on "a", "b", "c" respectively. keys := []string{"a", "b", "c"} @@ -463,7 +527,7 @@ func TestMacros(t *testing.T) { require.NoError(t, err) loader := newEvalLoader() - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) eval := tex(`event.data.ok == "true" || event.data.ids.exists(id, id == 'c')`) loader.AddEval(eval) ok, err := e.Add(ctx, eval) @@ -525,7 +589,7 @@ func TestAddRemove(t *testing.T) { loader := newEvalLoader() t.Run("With a basic aggregateable expression", func(t *testing.T) { - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) firstExpr := tex(`event.data.foo == "yes"`, "first-id") loader.AddEval(firstExpr) @@ -622,7 +686,7 @@ func TestAddRemove(t *testing.T) { }) t.Run("With a non-aggregateable expression due to inequality/GTE on strings", func(t *testing.T) { - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) ok, err := e.Add(ctx, loader.AddEval(tex(`event.data.foo != "no"`))) require.NoError(t, err) @@ -670,7 +734,7 @@ func TestEmptyExpressions(t *testing.T) { loader := newEvalLoader() - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) empty := loader.AddEval(tex(``, "id-1")) @@ -712,7 +776,7 @@ func TestEvaluate_Null(t *testing.T) { loader := newEvalLoader() - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) notNull := loader.AddEval(tex(`event.ts != null`, "id-1")) isNull := loader.AddEval(tex(`event.ts == null`, "id-2")) @@ -799,7 +863,7 @@ func TestEvaluate_Null(t *testing.T) { }) t.Run("Two idents aren't treated as nulls", func(t *testing.T) { - e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load) + e := NewAggregateEvaluator(parser, testBoolEvaluator, loader.Load, 0) idents := loader.AddEval(tex("event.data.a == event.data.b")) ok, err := e.Add(ctx, idents) require.NoError(t, err) diff --git a/vendor/golang.org/x/sync/semaphore/semaphore.go b/vendor/golang.org/x/sync/semaphore/semaphore.go new file mode 100644 index 0000000..30f632c --- /dev/null +++ b/vendor/golang.org/x/sync/semaphore/semaphore.go @@ -0,0 +1,136 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package semaphore provides a weighted semaphore implementation. +package semaphore // import "golang.org/x/sync/semaphore" + +import ( + "container/list" + "context" + "sync" +) + +type waiter struct { + n int64 + ready chan<- struct{} // Closed when semaphore acquired. +} + +// NewWeighted creates a new weighted semaphore with the given +// maximum combined weight for concurrent access. +func NewWeighted(n int64) *Weighted { + w := &Weighted{size: n} + return w +} + +// Weighted provides a way to bound concurrent access to a resource. +// The callers can request access with a given weight. +type Weighted struct { + size int64 + cur int64 + mu sync.Mutex + waiters list.List +} + +// Acquire acquires the semaphore with a weight of n, blocking until resources +// are available or ctx is done. On success, returns nil. On failure, returns +// ctx.Err() and leaves the semaphore unchanged. +// +// If ctx is already done, Acquire may still succeed without blocking. +func (s *Weighted) Acquire(ctx context.Context, n int64) error { + s.mu.Lock() + if s.size-s.cur >= n && s.waiters.Len() == 0 { + s.cur += n + s.mu.Unlock() + return nil + } + + if n > s.size { + // Don't make other Acquire calls block on one that's doomed to fail. + s.mu.Unlock() + <-ctx.Done() + return ctx.Err() + } + + ready := make(chan struct{}) + w := waiter{n: n, ready: ready} + elem := s.waiters.PushBack(w) + s.mu.Unlock() + + select { + case <-ctx.Done(): + err := ctx.Err() + s.mu.Lock() + select { + case <-ready: + // Acquired the semaphore after we were canceled. Rather than trying to + // fix up the queue, just pretend we didn't notice the cancelation. + err = nil + default: + isFront := s.waiters.Front() == elem + s.waiters.Remove(elem) + // If we're at the front and there're extra tokens left, notify other waiters. + if isFront && s.size > s.cur { + s.notifyWaiters() + } + } + s.mu.Unlock() + return err + + case <-ready: + return nil + } +} + +// TryAcquire acquires the semaphore with a weight of n without blocking. +// On success, returns true. On failure, returns false and leaves the semaphore unchanged. +func (s *Weighted) TryAcquire(n int64) bool { + s.mu.Lock() + success := s.size-s.cur >= n && s.waiters.Len() == 0 + if success { + s.cur += n + } + s.mu.Unlock() + return success +} + +// Release releases the semaphore with a weight of n. +func (s *Weighted) Release(n int64) { + s.mu.Lock() + s.cur -= n + if s.cur < 0 { + s.mu.Unlock() + panic("semaphore: released more than held") + } + s.notifyWaiters() + s.mu.Unlock() +} + +func (s *Weighted) notifyWaiters() { + for { + next := s.waiters.Front() + if next == nil { + break // No more waiters blocked. + } + + w := next.Value.(waiter) + if s.size-s.cur < w.n { + // Not enough tokens for the next waiter. We could keep going (to try to + // find a waiter with a smaller request), but under load that could cause + // starvation for large requests; instead, we leave all remaining waiters + // blocked. + // + // Consider a semaphore used as a read-write lock, with N tokens, N + // readers, and one writer. Each reader can Acquire(1) to obtain a read + // lock. The writer can Acquire(N) to obtain a write lock, excluding all + // of the readers. If we allow the readers to jump ahead in the queue, + // the writer will starve — there is always one token available for every + // reader. + break + } + + s.cur += w.n + s.waiters.Remove(next) + close(w.ready) + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 966c989..5c43557 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -61,6 +61,7 @@ golang.org/x/exp/slices # golang.org/x/sync v0.6.0 ## explicit; go 1.18 golang.org/x/sync/errgroup +golang.org/x/sync/semaphore # golang.org/x/text v0.9.0 ## explicit; go 1.17 golang.org/x/text/transform