Skip to content

Commit

Permalink
Refactor lifted args to use basic string walker
Browse files Browse the repository at this point in the history
This removes regexp, taking >= 60% less CPU time
  • Loading branch information
tonyhb committed Jan 5, 2024
1 parent a458261 commit 713e648
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 58 deletions.
46 changes: 2 additions & 44 deletions caching_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package expr

import (
"regexp"
"strconv"
"strings"
"sync"
"sync/atomic"

Expand All @@ -13,13 +11,9 @@ import (

var (
doubleQuoteMatch *regexp.Regexp

Check failure on line 13 in caching_parser.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

var `doubleQuoteMatch` is unused (unused)
replace = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}
replace = []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t"}
)

func init() {
doubleQuoteMatch = regexp.MustCompile(`"[^"]*"`)
}

// NewCachingParser returns a CELParser which lifts quoted literals out of the expression
// as variables and uses caching to cache expression parsing, resulting in improved
// performance when parsing expressions.
Expand All @@ -40,43 +34,7 @@ type cachingParser struct {
misses int64
}

// liftLiterals lifts quoted literals into variables, allowing us to normalize
// expressions to increase cache hit rates.
func liftLiterals(expr string) (string, map[string]any) {
// TODO: Optimize this please. Use strconv.Unquote as the basis, and perform
// searches across each index quotes.

// If this contains an escape sequence (eg. `\` or `\'`), skip the lifting
// of literals out of the expression.
if strings.Contains(expr, `\"`) || strings.Contains(expr, `\'`) {
return expr, nil
}

var (
counter int
vars = map[string]any{}
)

rewrite := func(str string) string {
if counter > len(replace) {
return str
}

idx := replace[counter]
if val, err := strconv.Unquote(str); err == nil {
str = val
}
vars[idx] = str

counter++
return VarPrefix + idx
}

expr = doubleQuoteMatch.ReplaceAllStringFunc(expr, rewrite)
return expr, vars
}

func (c *cachingParser) Parse(expr string) (*cel.Ast, *cel.Issues, map[string]any) {
func (c *cachingParser) Parse(expr string) (*cel.Ast, *cel.Issues, LiftedArgs) {
expr, vars := liftLiterals(expr)

// TODO: ccache, when I have internet.
Expand Down
4 changes: 2 additions & 2 deletions caching_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func TestCachingParser_CachesSame(t *testing.T) {
var (
prevAST *cel.Ast
prevIssues *cel.Issues
prevVars map[string]any
prevVars LiftedArgs
)

t.Run("With an uncached expression", func(t *testing.T) {
Expand Down Expand Up @@ -61,7 +61,7 @@ func TestCachingParser_CacheIgnoreLiterals_Unescaped(t *testing.T) {
var (
prevAST *cel.Ast
prevIssues *cel.Issues
prevVars map[string]any
prevVars LiftedArgs
)

t.Run("With an uncached expression", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion expr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func evaluate(b *testing.B, i int, parser TreeParser) error {

func TestEvaluate(t *testing.T) {
ctx := context.Background()
parser, err := newParser()
parser, err := NewTreeParser(NewCachingParser(newEnv()))
require.NoError(t, err)
e := NewAggregateEvaluator(parser, testBoolEvaluator)

Expand Down
153 changes: 153 additions & 0 deletions lift.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package expr

import (
"fmt"
"strconv"
"strings"
)

type LiftedArgs interface {
Get(val string) (any, bool)
}

// liftLiterals lifts quoted literals into variables, allowing us to normalize
// expressions to increase cache hit rates.
func liftLiterals(expr string) (string, LiftedArgs) {
// TODO: Lift numeric literals out of expressions.
// If this contains an escape sequence (eg. `\` or `\'`), skip the lifting
// of literals out of the expression.
if strings.Contains(expr, `\"`) || strings.Contains(expr, `\'`) {
return expr, nil
}

lp := liftParser{expr: expr}
return lp.lift()
}

type liftParser struct {
expr string
idx int

rewritten *strings.Builder

// varCounter counts the number of variables lifted.
varCounter int

vars pointerArgMap
}

func (l *liftParser) lift() (string, LiftedArgs) {
l.vars = pointerArgMap{
expr: l.expr,
vars: map[string]argMapValue{},
}

l.rewritten = &strings.Builder{}

for l.idx < len(l.expr) {
char := l.expr[l.idx]

l.idx++

switch char {
case '"':
// Consume the string arg.
val := l.consumeString('"')
l.addLiftedVar(val)

case '\'':
val := l.consumeString('\'')
l.addLiftedVar(val)
default:
l.rewritten.WriteByte(char)
}
}

return l.rewritten.String(), l.vars
}

func (l *liftParser) addLiftedVar(val argMapValue) {
if l.varCounter >= len(replace) {
// Do nothing.
str := val.get(l.expr)
l.rewritten.WriteString(strconv.Quote(str.(string)))
return
}

letter := replace[l.varCounter]

l.vars.vars[letter] = val
l.varCounter++

l.rewritten.WriteString(VarPrefix + letter)
}

func (l *liftParser) consumeString(quoteChar byte) argMapValue {
offset := l.idx
length := 0
for l.idx < len(l.expr) {
char := l.expr[l.idx]

// Grab the next char for evaluation.
l.idx++

if char == '\\' && l.peek() == quoteChar {
// If we're escaping the quote character, ignore it.
l.idx++
length++
continue
}

if char == quoteChar {
return argMapValue{offset, length}
}

// Only now has the length of the inner quote increased.
length++
}

// Should never happen: we should always find the ending string quote, as the
// expression should have already been validated.
panic(fmt.Sprintf("unable to parse quoted string: `%s` (offset %d)", l.expr, offset))
}

func (l *liftParser) peek() byte {
if (l.idx + 1) >= len(l.expr) {
return 0x0
}
return l.expr[l.idx+1]
}

// pointerArgMap takes the original expression, and adds pointers to the original expression
// in order to grab variables.
//
// It does this by pointing to the offset and length of data within the expression, as opposed
// to extracting the value into a new string. This greatly reduces memory growth & heap allocations.
type pointerArgMap struct {
expr string
vars map[string]argMapValue
}

func (p pointerArgMap) Get(key string) (any, bool) {
val, ok := p.vars[key]
if !ok {
return nil, false
}
data := val.get(p.expr)
return data, true
}

// argMapValue represents an offset and length for an argument in an expression string
type argMapValue [2]int

func (a argMapValue) get(expr string) any {
data := expr[a[0] : a[0]+a[1]]
return data
}

type regularArgMap map[string]any

func (p regularArgMap) Get(key string) (any, bool) {
val, ok := p[key]
return val, ok
}
15 changes: 7 additions & 8 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type TreeParser interface {
// to provide a caching layer on top of *cel.Env to optimize parsing, as it's
// the slowest part of the expression process.
type CELParser interface {
Parse(expr string) (*cel.Ast, *cel.Issues, map[string]any)
Parse(expr string) (*cel.Ast, *cel.Issues, LiftedArgs)
}

// EnvParser turns a *cel.Env into a CELParser.
Expand All @@ -41,7 +41,7 @@ type envparser struct {
env *cel.Env
}

func (e envparser) Parse(txt string) (*cel.Ast, *cel.Issues, map[string]any) {
func (e envparser) Parse(txt string) (*cel.Ast, *cel.Issues, LiftedArgs) {
ast, iss := e.env.Parse(txt)
return ast, iss, nil
}
Expand Down Expand Up @@ -98,8 +98,7 @@ type ParsedExpression struct {
// share the same expression. Using the same expression allows us
// to cache and skip CEL parsing, which is the slowest aspect of
// expression matching.
//
Vars map[string]any
Vars LiftedArgs

// Evaluable stores the original evaluable interface that was parsed.
Evaluable Evaluable
Expand Down Expand Up @@ -330,7 +329,7 @@ type expr struct {
// It does this by iterating through the expression, amending the current `group` until
// an or expression is found. When an or expression is found, we create another group which
// is mutated by the iteration.
func navigateAST(nav expr, parent *Node, vars map[string]any) ([]*Node, error) {
func navigateAST(nav expr, parent *Node, vars LiftedArgs) ([]*Node, error) {
// on the very first call to navigateAST, ensure that we set the first node
// inside the nodemap.
result := []*Node{}
Expand Down Expand Up @@ -464,7 +463,7 @@ func peek(nav expr, operator string) []expr {
// callToPredicate transforms a function call within an expression (eg `>`) into
// a Predicate struct for our matching engine. It ahandles normalization of
// LHS/RHS plus inversions.
func callToPredicate(item celast.Expr, negated bool, vars map[string]any) *Predicate {
func callToPredicate(item celast.Expr, negated bool, vars LiftedArgs) *Predicate {
fn := item.AsCall().FunctionName()
if fn == operators.LogicalAnd || fn == operators.LogicalOr {
// Quit early, as we descend into these while iterating through the tree when calling this.
Expand Down Expand Up @@ -547,7 +546,7 @@ func callToPredicate(item celast.Expr, negated bool, vars map[string]any) *Predi
}

if aIsVar {
if val, ok := vars[strings.TrimPrefix(identA, VarPrefix)]; ok {
if val, ok := vars.Get(strings.TrimPrefix(identA, VarPrefix)); ok {
// Normalize.
literal = val
identA = identB
Expand All @@ -556,7 +555,7 @@ func callToPredicate(item celast.Expr, negated bool, vars map[string]any) *Predi
}

if bIsVar {
if val, ok := vars[strings.TrimPrefix(identB, VarPrefix)]; ok {
if val, ok := vars.Get(strings.TrimPrefix(identB, VarPrefix)); ok {
// Normalize.
literal = val
identB = ""
Expand Down
6 changes: 3 additions & 3 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ func TestParse_LiftedVars(t *testing.T) {
Operator: operators.Equals,
},
},
Vars: map[string]any{
Vars: regularArgMap{
"a": "foo",
},
},
Expand All @@ -1013,7 +1013,7 @@ func TestParse_LiftedVars(t *testing.T) {
Operator: operators.Equals,
},
},
Vars: map[string]any{
Vars: regularArgMap{
"a": "bar",
},
},
Expand All @@ -1029,7 +1029,7 @@ func TestParse_LiftedVars(t *testing.T) {
Operator: operators.Equals,
},
},
Vars: map[string]any{
Vars: regularArgMap{
"a": "bar",
},
},
Expand Down

0 comments on commit 713e648

Please sign in to comment.