Skip to content

Commit

Permalink
AST: Lazy evaluate boolean child nodes.
Browse files Browse the repository at this point in the history
This adds circuit breaking to boolean evaluation.
An AND boolean resolves to false if the first operand is false, and an OR boolean resolves to true if the first operand is true.
With this, second operands and skipped if they can have no sway on the node's result.
  • Loading branch information
Antoine Popineau committed Jan 14, 2025
1 parent 9f2e035 commit 896f4f0
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 9 deletions.
2 changes: 1 addition & 1 deletion integration_test/scenario_flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ func createDecisions(
transactionPayloadJson = []byte(`{
"object_id": "{transaction_id}",
"updated_at": "2020-01-01T00:00:00Z",
"account_id": "{account_id_approve}",
"account_id": "{account_id_decline}",
"amount": 0
}`)
approveDivisionByZeroDecision := createAndTestDecision(ctx, t, transactionPayloadJson, table,
Expand Down
24 changes: 24 additions & 0 deletions models/ast/ast_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ type FuncAttributes struct {
// However, it is not consumed anywhere, and it is in NO WAY enforced by the compiler or even the runtime.
// The only source of truth for what named children an AST must/can have is in the ast nodes Evaluate function.
NamedArguments []string
// A function can define LazyChildEvaluation indicating that its children should be evaluated lazily,
// considering for every one of them if evaluation should continue or not. For the result value of one child, the
// function returns whether evaluation of subsequent children should continue (true) or not (false).
LazyChildEvaluation func(NodeEvaluation) bool
}

// If number of arguments -1 the function can take any number of arguments
Expand Down Expand Up @@ -141,10 +145,14 @@ var FuncAttributesMap = map[Function]FuncAttributes{
FUNC_AND: {
DebugName: "FUNC_AND",
AstName: "And",
// Boolean AND returns false if any child node evaluates to false
LazyChildEvaluation: shortCircuitIfFalse,
},
FUNC_OR: {
DebugName: "FUNC_OR",
AstName: "Or",
// Boolean OR returns true if any child nodes evluates to true
LazyChildEvaluation: shortCircuitIfTrue,
},
FUNC_TIME_ADD: {
DebugName: "FUNC_TIME_ADD",
Expand Down Expand Up @@ -286,3 +294,19 @@ func NewNodeDatabaseAccess(tableName string, fieldName string, path []string) No
AddNamedChild(AttributeFuncDbAccess.ArgumentFieldName, NewNodeConstant(fieldName)).
AddNamedChild(AttributeFuncDbAccess.ArgumentPathName, NewNodeConstant(path))
}

func shortCircuitIfTrue(res NodeEvaluation) bool {
if b, ok := res.ReturnValue.(bool); ok {
// If node returned true, we stop (return !true = false), otherwise, continue with true
return !b
}
return true
}

func shortCircuitIfFalse(res NodeEvaluation) bool {
if b, ok := res.ReturnValue.(bool); ok {
// If node returned false, we stop (return false), otherwise, continue with true
return b
}
return true
}
34 changes: 34 additions & 0 deletions pure_utils/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,37 @@ func MapValuesErr[Key comparable, T any, U any](src map[Key]T, f func(T) (U, err
}
return result, nil
}

// MapWhile maps over items in a slice and produces a slice of items of another type.
// Contrary to regular Map(), the callbacks returns a second boolean value to indicate if the operation
// should continue. It stops whenever the callback returns false.
func MapWhile[T, U any](src []T, f func(T) (U, bool)) []U {
us := make([]U, 0, len(src))
for i := range src {
item, next := f(src[i])

us = append(us, item)

if !next {
break
}
}
return us
}

// MapValuesWhile maps over a map's values in a slice and produces a slice of items of another type.
// Contrary to regular MapValues(), the callbacks returns a second boolean value to indicate if the operation
// should continue. It stops whenever the callback returns false.
func MapValuesWhile[Key comparable, T any, U any](src map[Key]T, f func(T) (U, bool)) map[Key]U {
result := make(map[Key]U, len(src))
for key, value := range src {
item, next := f(value)

result[key] = item

if !next {
break
}
}
return result
}
18 changes: 14 additions & 4 deletions usecases/ast_eval/evaluate_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,29 @@ func EvaluateAst(ctx context.Context, environment AstEvaluationEnvironment, node

childEvaluationFail := false

evalChild := func(child ast.Node) ast.NodeEvaluation {
// Only interested in lazy callback which will have default value if an error is returned
attrs, _ := node.Function.Attributes()

evalChild := func(child ast.Node) (childEval ast.NodeEvaluation, evalNext bool) {
childEval, ok := EvaluateAst(ctx, environment, child)

if !ok {
childEvaluationFail = true
return
}
return childEval

// Should we continue evaluating subsequent children nodes? If the parent node is not lazy, yes (first condition),
// otherwise, it is determined by the return value of the LazyChildEvaluation function.
evalNext = environment.disableCircuitBreaking || attrs.LazyChildEvaluation == nil || attrs.LazyChildEvaluation(childEval)

return
}

// eval each child
evaluation := ast.NodeEvaluation{
Function: node.Function,
Children: pure_utils.Map(node.Children, evalChild),
NamedChildren: pure_utils.MapValues(node.NamedChildren, evalChild),
Children: pure_utils.MapWhile(node.Children, evalChild),
NamedChildren: pure_utils.MapValuesWhile(node.NamedChildren, evalChild),
}

if childEvaluationFail {
Expand Down
87 changes: 87 additions & 0 deletions usecases/ast_eval/evaluate_ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/checkmarble/marble-backend/models/ast"
"github.com/checkmarble/marble-backend/utils"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -91,3 +92,89 @@ func NewAstOrFalse() ast.Node {
AddChild(ast.Node{Constant: false}).
AddChild(ast.Node{Constant: false})
}

func TestLazyAnd(t *testing.T) {
environment := NewAstEvaluationEnvironment()

for _, value := range []bool{true, false} {
root := ast.Node{Function: ast.FUNC_AND}.
AddChild(ast.Node{Function: ast.FUNC_EQUAL}.
AddChild(ast.Node{Constant: value}).
AddChild(ast.Node{Constant: true})).
AddChild(ast.Node{Function: ast.FUNC_UNKNOWN})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)

switch value {
case false:
assert.True(t, ok, "unknown node should not be evaluated because of AND lazy evaluation")
assert.Len(t, evaluation.Children, 1, "lazy evaluated AND should only have one child")
case true:
assert.False(t, ok, "unknown node should be evaluated because of AND lazy evaluation")
assert.Len(t, evaluation.Children, 2, "lazy evaluated AND should have two children")
}
}
}

func TestLazyOr(t *testing.T) {
environment := NewAstEvaluationEnvironment()

for _, value := range []bool{true, false} {
root := ast.Node{Function: ast.FUNC_OR}.
AddChild(ast.Node{Function: ast.FUNC_EQUAL}.
AddChild(ast.Node{Constant: value}).
AddChild(ast.Node{Constant: true})).
AddChild(ast.Node{Function: ast.FUNC_UNKNOWN})

evaluation, ok := EvaluateAst(context.TODO(), environment, root)

switch value {
case true:
assert.True(t, ok, "unknown node should not be evaluated because of OR lazy evaluation")
assert.Len(t, evaluation.Children, 1, "lazy evaluates OR should only have one child")
case false:
assert.False(t, ok, "unknown node should be evaluated because of OR lazy evaluation")
assert.Len(t, evaluation.Children, 2, "lazy evaluated AND should have two children")
}
}
}

func TestLazyBooleanNulls(t *testing.T) {
tts := []struct {
fn ast.Function
lhs, rhs, res *bool
}{
{ast.FUNC_OR, nil, utils.Ptr(true), utils.Ptr(true)},
{ast.FUNC_OR, utils.Ptr(true), nil, utils.Ptr(true)},
{ast.FUNC_OR, nil, utils.Ptr(false), nil},
{ast.FUNC_OR, utils.Ptr(false), nil, nil},
{ast.FUNC_AND, nil, utils.Ptr(true), nil},
{ast.FUNC_AND, utils.Ptr(true), nil, nil},
{ast.FUNC_AND, nil, utils.Ptr(false), utils.Ptr(false)},
{ast.FUNC_AND, utils.Ptr(false), nil, utils.Ptr(false)},
}

environment := NewAstEvaluationEnvironment()

for _, tt := range tts {
root := ast.Node{Function: tt.fn}

for _, op := range []*bool{tt.lhs, tt.rhs} {
switch op {
case nil:
root = root.AddChild(ast.Node{Constant: nil})
default:
root = root.AddChild(ast.Node{Constant: *op})
}
}

evaluation, _ := EvaluateAst(context.TODO(), environment, root)

switch {
case tt.res == nil:
assert.Equal(t, nil, evaluation.ReturnValue)
default:
assert.Equal(t, *tt.res, evaluation.ReturnValue)
}
}
}
9 changes: 8 additions & 1 deletion usecases/ast_eval/evaluate_environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import (
)

type AstEvaluationEnvironment struct {
availableFunctions map[ast.Function]evaluate.Evaluator
availableFunctions map[ast.Function]evaluate.Evaluator
disableCircuitBreaking bool
}

func (environment *AstEvaluationEnvironment) AddEvaluator(function ast.Function, evaluator evaluate.Evaluator) {
Expand All @@ -27,6 +28,12 @@ func (environment *AstEvaluationEnvironment) GetEvaluator(function ast.Function)
return nil, errors.New(fmt.Sprintf("function '%s' is not available", function.DebugString()))
}

func (environment AstEvaluationEnvironment) WithoutCircuitBreaking() AstEvaluationEnvironment {
environment.disableCircuitBreaking = true

return environment
}

func NewAstEvaluationEnvironment() AstEvaluationEnvironment {
environment := AstEvaluationEnvironment{
availableFunctions: make(map[ast.Function]evaluate.Evaluator),
Expand Down
4 changes: 3 additions & 1 deletion usecases/scenarios/scenario_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ func (validator *AstValidatorImpl) MakeDryRunEnvironment(ctx context.Context,
ClientObject: clientObject,
DataModel: dataModel,
DatabaseAccessReturnFakeValue: true,
})
}).
WithoutCircuitBreaking()

return env, nil
}
104 changes: 102 additions & 2 deletions usecases/scenarios/scenario_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestValidateScenarioIterationImpl_Validate(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment()
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
},
ExecutorFactory: executorFactory,
}
Expand Down Expand Up @@ -183,7 +183,107 @@ func TestValidateScenarioIterationImpl_Validate_notBool(t *testing.T) {
validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment()
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
},
ExecutorFactory: executorFactory,
}

siValidator := ValidateScenarioIterationImpl{
AstValidator: &validator,
}

result := siValidator.Validate(ctx, models.ScenarioAndIteration{
Scenario: scenario,
Iteration: scenarioIteration,
})
assert.NotEmpty(t, ScenarioValidationToError(result))
}

func TestValidationShouldBypassCircuitBreaking(t *testing.T) {
ctx := utils.StoreLoggerInContext(context.Background(), utils.NewLogger("text"))
scenario := models.Scenario{
Id: uuid.New().String(),
OrganizationId: uuid.New().String(),
Name: "scenario_name",
Description: "description",
TriggerObjectType: "object_type",
CreatedAt: time.Now(),
LiveVersionID: utils.Ptr(uuid.New().String()),
}

scenarioIterationID := uuid.New().String()
scenarioIteration := models.ScenarioIteration{
Id: scenarioIterationID,
OrganizationId: scenario.OrganizationId,
ScenarioId: scenario.Id,
Version: utils.Ptr(1),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
TriggerConditionAstExpression: utils.Ptr(ast.Node{
Constant: true,
}),
Rules: []models.Rule{
{
Id: "rule",
ScenarioIterationId: scenarioIterationID,
OrganizationId: scenario.OrganizationId,
DisplayOrder: 0,
Name: "rule",
Description: "description",
FormulaAstExpression: utils.Ptr(ast.Node{
Function: ast.FUNC_AND,
Constant: nil,
Children: []ast.Node{
{
Function: ast.FUNC_EQUAL,
Children: []ast.Node{
{Constant: 100},
{Constant: 101},
},
},
{
Function: ast.FUNC_EQUAL,
Children: []ast.Node{
{Constant: 100},
{Constant: "oplop"},
},
},
},
}),
ScoreModifier: 10,
CreatedAt: time.Now(),
},
},
ScoreReviewThreshold: utils.Ptr(100),
ScoreBlockAndReviewThreshold: utils.Ptr(1000),
ScoreDeclineThreshold: utils.Ptr(1000),
Schedule: "schedule",
}

exec := new(mocks.Executor)
executorFactory := new(mocks.ExecutorFactory)
executorFactory.On("NewExecutor").Once().Return(exec)
mdmr := new(mocks.DataModelRepository)
mdmr.On("GetDataModel", ctx, exec, scenario.OrganizationId, false).
Return(models.DataModel{
Version: "version",
Tables: map[string]models.Table{
"object_type": {
Name: "object_type",
Fields: map[string]models.Field{
"id": {
DataType: models.Int,
},
},
LinksToSingle: nil,
},
},
}, nil)

validator := AstValidatorImpl{
DataModelRepository: mdmr,
AstEvaluationEnvironmentFactory: func(params ast_eval.EvaluationEnvironmentFactoryParams) ast_eval.AstEvaluationEnvironment {
return ast_eval.NewAstEvaluationEnvironment().WithoutCircuitBreaking()
},
ExecutorFactory: executorFactory,
}
Expand Down

0 comments on commit 896f4f0

Please sign in to comment.