Skip to content

Commit

Permalink
Enable concurrent evaluations of matched expressions (#20)
Browse files Browse the repository at this point in the history
* Enable concurrent evaluations of matched expressions

Do this via semaphores.

* Paranoia gets ya

* waits

* Fix tests

* safety

* mas test fixesa
  • Loading branch information
tonyhb authored Jul 17, 2024
1 parent b064fe2 commit 90b4cde
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 86 deletions.
117 changes: 86 additions & 31 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (

"github.com/google/cel-go/common/operators"
"github.com/google/uuid"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)

var (
Expand Down Expand Up @@ -70,11 +72,17 @@ func NewAggregateEvaluator(
parser TreeParser,
eval ExpressionEvaluator,
evalLoader EvaluableLoader,
concurrency int64,
) AggregateEvaluator {
if concurrency <= 0 {
concurrency = 1
}

return &aggregator{
eval: eval,
parser: parser,
loader: evalLoader,
sem: semaphore.NewWeighted(concurrency),
engines: map[EngineType]MatchingEngine{
EngineTypeStringHash: newStringEqualityMatcher(),
EngineTypeNullMatch: newNullMatcher(),
Expand All @@ -92,6 +100,8 @@ type aggregator struct {
// engines records all engines
engines map[EngineType]MatchingEngine

sem *semaphore.Weighted

// lock prevents concurrent updates of data
lock *sync.RWMutex
// len stores the current len of aggregable expressions.
Expand Down Expand Up @@ -123,6 +133,7 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
err error
matched = int32(0)
result = []Evaluable{}
s sync.Mutex
)

// TODO: Concurrently match constant expressions using a semaphore for capacity.
Expand All @@ -131,25 +142,49 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
if err != nil {
return nil, 0, err
}
for _, expr := range constantEvals {
atomic.AddInt32(&matched, 1)

if expr.GetExpression() == "" {
result = append(result, expr)
continue
eg := errgroup.Group{}
for _, item := range constantEvals {
if err := a.sem.Acquire(ctx, 1); err != nil {
return result, matched, err
}

// NOTE: We don't need to add lifted expression variables,
// because match.Parsed.Evaluable() returns the original expression
// string.
ok, evalerr := a.eval(ctx, expr, data)
if evalerr != nil {
err = errors.Join(err, evalerr)
continue
}
if ok {
result = append(result, expr)
}
expr := item
eg.Go(func() error {
defer a.sem.Release(1)
defer func() {
if r := recover(); r != nil {
err = errors.Join(err, fmt.Errorf("recovered from panic in evaluate: %v", r))
}
}()

atomic.AddInt32(&matched, 1)

if expr.GetExpression() == "" {
s.Lock()
result = append(result, expr)
s.Unlock()
return nil
}

// NOTE: We don't need to add lifted expression variables,
// because match.Parsed.Evaluable() returns the original expression
// string.
ok, evalerr := a.eval(ctx, expr, data)
if evalerr != nil {
return evalerr
}
if ok {
s.Lock()
result = append(result, expr)
s.Unlock()
}
return nil
})
}

if werr := eg.Wait(); werr != nil {
err = errors.Join(err, werr)
}

matches, merr := a.AggregateMatch(ctx, data)
Expand All @@ -173,26 +208,46 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
// ID's length.
seen := map[uuid.UUID]struct{}{}

eg = errgroup.Group{}
for _, match := range evaluables {
if _, ok := seen[match.GetID()]; ok {
continue
if err := a.sem.Acquire(ctx, 1); err != nil {
return result, matched, err
}

atomic.AddInt32(&matched, 1)
// NOTE: We don't need to add lifted expression variables,
// because match.Parsed.Evaluable() returns the original expression
// string.
ok, evalerr := a.eval(ctx, match, data)
expr := match
eg.Go(func() error {
defer a.sem.Release(1)
defer func() {
if r := recover(); r != nil {
err = errors.Join(err, fmt.Errorf("recovered from panic in evaluate: %v", r))
}
}()

if _, ok := seen[expr.GetID()]; ok {
return nil
}

seen[match.GetID()] = struct{}{}
atomic.AddInt32(&matched, 1)
// NOTE: We don't need to add lifted expression variables,
// because match.Parsed.Evaluable() returns the original expression
// string.
ok, evalerr := a.eval(ctx, expr, data)

if evalerr != nil {
err = errors.Join(err, evalerr)
continue
}
if ok {
result = append(result, match)
}
seen[expr.GetID()] = struct{}{}
if evalerr != nil {
return evalerr
}
if ok {
s.Lock()
result = append(result, expr)
s.Unlock()
}
return nil
})
}

if werr := eg.Wait(); werr != nil {
err = errors.Join(err, werr)
}

return result, matched, err
Expand Down
Loading

0 comments on commit 90b4cde

Please sign in to comment.