Skip to content

Commit

Permalink
Keep all evaluables in memory
Browse files Browse the repository at this point in the history
Don't refer to the loader when evaluating.
  • Loading branch information
tonyhb committed Nov 5, 2024
1 parent 01208c9 commit 5e98af3
Showing 1 changed file with 38 additions and 16 deletions.
54 changes: 38 additions & 16 deletions expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ func NewAggregateEvaluator(
EngineTypeBTree: newNumberMatcher(),
},
lock: &sync.RWMutex{},
evals: map[uuid.UUID]Evaluable{},
constants: map[uuid.UUID]struct{}{},
mixed: map[uuid.UUID]struct{}{},
}
Expand All @@ -117,6 +118,9 @@ type aggregator struct {
// fastLen stores the current len of purely aggregable expressions.
fastLen int32

// evals stores all original evaluables in the aggregator.
evals map[uuid.UUID]Evaluable

// mixed stores the current len of mixed aggregable expressions,
// eg "foo == '1' && bar != '1'". This is becasue != isn't aggregateable,
// but the first `==` is used as a prefilter.
Expand Down Expand Up @@ -171,20 +175,22 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu

// Match constant expressions always.
a.lock.RLock()
uuids := make([]uuid.UUID, len(a.constants))
constantEvals := make([]Evaluable, len(a.constants))
n := 0
for id := range a.constants {
uuids[n] = id
n++
for uuid := range a.constants {
if eval, ok := a.evals[uuid]; ok {
constantEvals[n] = eval
n++
}
}
a.lock.RUnlock()
constantEvals, err := a.loader(ctx, uuids...)
if err != nil {
return nil, 0, err
}

eg := errgroup.Group{}
for _, item := range constantEvals {
if item == nil {
continue
}

if err := a.sem.Acquire(ctx, 1); err != nil {
return result, matched, err
}
Expand Down Expand Up @@ -234,15 +240,19 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu
err = errors.Join(err, merr)
}

// Load all evaluable instances directly.
uuids = make([]uuid.UUID, len(matches))
for n, m := range matches {
uuids[n] = m.Parsed.EvaluableID
}
evaluables, lerr := a.loader(ctx, uuids...)
if err != nil {
err = errors.Join(err, lerr)
fmt.Sprintf("%#v\n", matches)

Check failure on line 243 in expr.go

View workflow job for this annotation

GitHub Actions / lint

unusedresult: result of fmt.Sprintf call not used (govet)

// Load all evaluable instances directly from the match
a.lock.RLock()
n = 0
evaluables := make([]Evaluable, len(matches))
for _, el := range matches {
if eval, ok := a.evals[el.Parsed.EvaluableID]; ok {
evaluables[n] = eval
n++
}
}
a.lock.RUnlock()

// Each match here is a potential success. When other trees and operators which are walkable
// are added (eg. >= operators on strings), ensure that we find the correct number of matches
Expand All @@ -253,6 +263,10 @@ func (a *aggregator) Evaluate(ctx context.Context, data map[string]any) ([]Evalu

eg = errgroup.Group{}
for _, match := range evaluables {
if match == nil {
continue
}

if err := a.sem.Acquire(ctx, 1); err != nil {
return result, matched, err
}
Expand Down Expand Up @@ -386,6 +400,10 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (float64, error) {
return -1, err
}

a.lock.Lock()
a.evals[eval.GetID()] = eval
a.lock.Unlock()

if eval.GetExpression() == "" || parsed.HasMacros {
// This is an empty expression which always matches.
a.lock.Lock()
Expand Down Expand Up @@ -433,6 +451,10 @@ func (a *aggregator) Add(ctx context.Context, eval Evaluable) (float64, error) {
}

func (a *aggregator) Remove(ctx context.Context, eval Evaluable) error {
a.lock.Lock()
delete(a.evals, eval.GetID())
a.lock.Unlock()

if eval.GetExpression() == "" {
return a.removeConstantEvaluable(ctx, eval)
}
Expand Down

0 comments on commit 5e98af3

Please sign in to comment.