diff --git a/README.md b/README.md
index c8e5595..8b730b3 100644
--- a/README.md
+++ b/README.md
@@ -2,12 +2,10 @@
 
 ![Go version](https://img.shields.io/github/go-mod/go-version/ccremer/go-command-pipeline)
 [![Version](https://img.shields.io/github/v/release/ccremer/go-command-pipeline)][releases]
-[![GitHub downloads](https://img.shields.io/github/downloads/ccremer/go-command-pipeline/total)][releases]
 [![Go Report Card](https://goreportcard.com/badge/github.com/ccremer/go-command-pipeline)][goreport]
 [![Codecov](https://img.shields.io/codecov/c/github/ccremer/go-command-pipeline?token=XGOC4XUMJ5)][codecov]
 
-Small Go utility that executes high-level actions in a pipeline fashion.
-Especially useful when combined with the Facade design pattern.
+Small Go utility that executes business actions in a pipeline.
 
 ## Usage
 
@@ -20,16 +18,71 @@ import (
 func main() {
 	p := pipeline.NewPipeline()
 	p.WithSteps(
-		predicate.ToStep("clone repository", CloneGitRepository(), predicate.Not(DirExists("my-repo"))),
-		pipeline.NewStep("checkout branch", CheckoutBranch()),
-		pipeline.NewStep("pull", Pull()),
+		pipeline.NewStep("define random number", defineNumber),
+		pipeline.NewStepFromFunc("print number", printNumber),
 	)
 	result := p.Run()
 	if !result.IsSuccessful() {
 		log.Fatal(result.Err)
 	}
 }
+
+func defineNumber(ctx pipeline.Context) pipeline.Result {
+	ctx.SetValue("number", rand.Int())
+	return pipeline.Result{}
+}
+
+// Let's assume this is a business function that can fail.
+// Composing many small functions that return errors gives you "automatic" fail-on-first-error pipelines.
+func printNumber(ctx pipeline.Context) error {
+	_, err := fmt.Println(ctx.IntValue("number", 0))
+	return err
+}
+```
+
+## Who is it for
+
+This utility is interesting for you if you have many business functions that are executed sequentially, each with its own error handling.
+Do you grow tired of Go's tedious error handling, where in over 90% of cases all you do is pass the error "up" the stack, only to log it at the root?
+This utility helps you focus on the business logic by dividing each failure-prone action into small steps, since the pipeline aborts on the first error.
+
+Consider the following conventional example:
+```go
+func Persist(data Data) error {
+	err := database.prepareTransaction()
+	if err != nil {
+		return err
+	}
+	err = database.executeQuery("SOME QUERY", data)
+	if err != nil {
+		return err
+	}
+	err = database.commit()
+	return err
+}
+```
+We have tons of `if err != nil` clauses that bloat the function with more error handling than actual business logic.
+
+It could be simplified to something like this:
+```go
+func Persist(data Data) error {
+	p := pipeline.NewPipeline().WithSteps(
+		pipeline.NewStep("prepareTransaction", prepareTransaction()),
+		pipeline.NewStep("executeQuery", executeQuery(data)),
+		pipeline.NewStep("commitTransaction", commit()),
+	)
+	return p.Run().Err
+}
+
+func executeQuery(data Data) pipeline.ActionFunc {
+	return func(_ pipeline.Context) pipeline.Result {
+		err := database.executeQuery("SOME QUERY", data)
+		return pipeline.Result{Err: err}
+	}
+}
+...
 ```
+While setting up a pipeline seems to add more lines, it makes it much easier to understand what `Persist()` does, without all the error handling.
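+
+For illustration, a step can also be skipped based on a predicate, or followed up with a `ResultHandler` once it finished.
+The following is a minimal sketch combining both, reusing the `DirExists` predicate and the Git actions from `examples/git_test.go`:
+```go
+p := pipeline.NewPipeline().WithSteps(
+	// only clone if the target directory doesn't exist yet
+	predicate.ToStep("clone repository", CloneGitRepository(), predicate.Not(DirExists("my-repo"))),
+	// the handler is called with the step's Result after the action completed
+	pipeline.NewStep("checkout branch", CheckoutBranch()).WithResultHandler(logSuccess),
+	pipeline.NewStep("pull", Pull()),
+)
+result := p.Run()
+```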
[releases]: https://github.com/ccremer/go-command-pipeline/releases [codecov]: https://app.codecov.io/gh/ccremer/go-command-pipeline diff --git a/codecov.yml b/codecov.yml index 62e1929..872178e 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,6 +1,9 @@ comment: false coverage: status: + patch: + default: + threshold: 50% project: default: - threshold: 5% + threshold: 10% diff --git a/context.go b/context.go new file mode 100644 index 0000000..a8cb4b7 --- /dev/null +++ b/context.go @@ -0,0 +1,100 @@ +package pipeline + +// Context contains data relevant for the pipeline execution. +// It's primary purpose is to store and retrieve data within an ActionFunc. +type Context interface { + // Value returns the raw value identified by key. + // Returns nil if the key doesn't exist. + Value(key interface{}) interface{} + // ValueOrDefault returns the value identified by key if it exists. + // If not, then the given default value is returned. + ValueOrDefault(key interface{}, defaultValue interface{}) interface{} + // StringValue is a sugared accessor like ValueOrDefault, but converts the value to string. + // If the key cannot be found or if the value is not of type string, then the defaultValue is returned. + StringValue(key interface{}, defaultValue string) string + // BoolValue is a sugared accessor like ValueOrDefault, but converts the value to bool. + // If the key cannot be found or if the value is not of type bool, then the defaultValue is returned. + BoolValue(key interface{}, defaultValue bool) bool + // IntValue is a sugared accessor like ValueOrDefault, but converts the value to int. + // If the key cannot be found or if the value is not of type int, then the defaultValue is returned. + IntValue(key interface{}, defaultValue int) int + // SetValue sets the value at the given key. + SetValue(key interface{}, value interface{}) +} + +// DefaultContext implements Context using a Map internally. +type DefaultContext struct { + values map[interface{}]interface{} +} + +// Value implements Context.Value. +func (ctx *DefaultContext) Value(key interface{}) interface{} { + if ctx.values == nil { + return nil + } + return ctx.values[key] +} + +// ValueOrDefault implements Context.ValueOrDefault. +func (ctx *DefaultContext) ValueOrDefault(key interface{}, defaultValue interface{}) interface{} { + if ctx.values == nil { + return defaultValue + } + if raw, exists := ctx.values[key]; exists { + return raw + } + return defaultValue +} + +// StringValue implements Context.StringValue. +func (ctx *DefaultContext) StringValue(key interface{}, defaultValue string) string { + if ctx.values == nil { + return defaultValue + } + raw, exists := ctx.values[key] + if !exists { + return defaultValue + } + if strValue, isString := raw.(string); isString { + return strValue + } + return defaultValue +} + +// BoolValue implements Context.BoolValue. +func (ctx *DefaultContext) BoolValue(key interface{}, defaultValue bool) bool { + if ctx.values == nil { + return defaultValue + } + raw, exists := ctx.values[key] + if !exists { + return defaultValue + } + if boolValue, isBool := raw.(bool); isBool { + return boolValue + } + return defaultValue +} + +// IntValue implements Context.IntValue. +func (ctx *DefaultContext) IntValue(key interface{}, defaultValue int) int { + if ctx.values == nil { + return defaultValue + } + raw, exists := ctx.values[key] + if !exists { + return defaultValue + } + if intValue, isInt := raw.(int); isInt { + return intValue + } + return defaultValue +} + +// SetValue implements Context.SetValue. 
+func (ctx *DefaultContext) SetValue(key interface{}, value interface{}) { + if ctx.values == nil { + ctx.values = map[interface{}]interface{}{} + } + ctx.values[key] = value +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..296ddd8 --- /dev/null +++ b/context_test.go @@ -0,0 +1,164 @@ +package pipeline + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +const stringKey = "stringKey" +const boolKey = "boolKey" +const intKey = "intKey" +const valueKey = "value" + +func TestDefaultContext_Implements_Context(t *testing.T) { + assert.Implements(t, (*Context)(nil), new(DefaultContext)) +} + +type valueTestCase struct { + givenValues map[interface{}]interface{} + + defaultBool bool + defaultString string + defaultInt int + + expectedBool bool + expectedString string + expectedInt int +} + +var valueTests = map[string]valueTestCase{ + "GivenNilValues_ThenExpectDefaults": { + givenValues: nil, + }, + "GivenNonExistingKey_ThenExpectDefaults": { + givenValues: map[interface{}]interface{}{}, + defaultBool: true, + expectedBool: true, + defaultString: "default", + expectedString: "default", + defaultInt: 10, + expectedInt: 10, + }, + "GivenExistingKey_WhenInvalidType_ThenExpectDefaults": { + givenValues: map[interface{}]interface{}{ + boolKey: "invalid", + stringKey: 0, + intKey: "invalid", + }, + defaultBool: true, + expectedBool: true, + defaultString: "default", + expectedString: "default", + defaultInt: 10, + expectedInt: 10, + }, + "GivenExistingKey_WhenValidType_ThenExpectValues": { + givenValues: map[interface{}]interface{}{ + boolKey: true, + stringKey: "string", + intKey: 10, + }, + expectedBool: true, + expectedString: "string", + expectedInt: 10, + }, +} + +func TestDefaultContext_BoolValue(t *testing.T) { + for name, tt := range valueTests { + t.Run(name, func(t *testing.T) { + ctx := DefaultContext{values: tt.givenValues} + result := ctx.BoolValue(boolKey, tt.defaultBool) + assert.Equal(t, tt.expectedBool, result) + }) + } +} + +func TestDefaultContext_StringValue(t *testing.T) { + for name, tt := range valueTests { + t.Run(name, func(t *testing.T) { + ctx := DefaultContext{values: tt.givenValues} + result := ctx.StringValue(stringKey, tt.defaultString) + assert.Equal(t, tt.expectedString, result) + }) + } +} + +func TestDefaultContext_IntValue(t *testing.T) { + for name, tt := range valueTests { + t.Run(name, func(t *testing.T) { + ctx := DefaultContext{values: tt.givenValues} + result := ctx.IntValue(intKey, tt.defaultInt) + assert.Equal(t, tt.expectedInt, result) + }) + } +} + +func TestDefaultContext_SetValue(t *testing.T) { + ctx := DefaultContext{values: map[interface{}]interface{}{}} + ctx.SetValue(stringKey, "string") + assert.Equal(t, "string", ctx.values[stringKey]) +} + +func TestDefaultContext_Value(t *testing.T) { + t.Run("GivenNilValues_ThenExpectNil", func(t *testing.T) { + ctx := DefaultContext{values: nil} + result := ctx.Value(valueKey) + assert.Nil(t, result) + }) + t.Run("GivenNonExistingKey_ThenExpectNil", func(t *testing.T) { + ctx := DefaultContext{values: map[interface{}]interface{}{}} + result := ctx.Value(valueKey) + assert.Nil(t, result) + }) + t.Run("GivenExistingKey_WhenKeyContainsNil_ThenExpectNil", func(t *testing.T) { + ctx := DefaultContext{values: map[interface{}]interface{}{ + valueKey: nil, + }} + result := ctx.Value(valueKey) + assert.Nil(t, result) + }) +} + +func TestDefaultContext_ValueOrDefault(t *testing.T) { + t.Run("GivenNilValues_ThenExpectDefault", func(t *testing.T) { + ctx := 
DefaultContext{values: nil} + result := ctx.ValueOrDefault(valueKey, valueKey) + assert.Equal(t, result, valueKey) + }) + t.Run("GivenNonExistingKey_ThenExpectDefault", func(t *testing.T) { + ctx := DefaultContext{values: map[interface{}]interface{}{}} + result := ctx.ValueOrDefault(valueKey, valueKey) + assert.Equal(t, result, valueKey) + }) + t.Run("GivenExistingKey_ThenExpectValue", func(t *testing.T) { + ctx := DefaultContext{values: map[interface{}]interface{}{ + valueKey: valueKey, + }} + result := ctx.ValueOrDefault(valueKey, "default") + assert.Equal(t, result, valueKey) + }) +} + +func ExampleDefaultContext_BoolValue() { + ctx := DefaultContext{} + ctx.SetValue("key", true) + fmt.Println(ctx.BoolValue("key", false)) + // Output: true +} + +func ExampleDefaultContext_StringValue() { + ctx := DefaultContext{} + ctx.SetValue("key", "string") + fmt.Println(ctx.StringValue("key", "default")) + // Output: string +} + +func ExampleDefaultContext_IntValue() { + ctx := DefaultContext{} + ctx.SetValue("key", 1) + fmt.Println(ctx.IntValue("key", 0)) + // Output: 1 +} diff --git a/examples/context_test.go b/examples/context_test.go new file mode 100644 index 0000000..f6b7f71 --- /dev/null +++ b/examples/context_test.go @@ -0,0 +1,34 @@ +//+build examples + +package examples + +import ( + "fmt" + "math/rand" + "testing" + + pipeline "github.com/ccremer/go-command-pipeline" +) + +func TestExample_Context(t *testing.T) { + // Create pipeline with defaults + p := pipeline.NewPipeline() + p.WithSteps( + pipeline.NewStep("define random number", defineNumber), + pipeline.NewStepFromFunc("print number", printNumber), + ) + result := p.Run() + if !result.IsSuccessful() { + t.Fatal(result.Err) + } +} + +func defineNumber(ctx pipeline.Context) pipeline.Result { + ctx.SetValue("number", rand.Int()) + return pipeline.Result{} +} + +func printNumber(ctx pipeline.Context) error { + _, err := fmt.Println(ctx.IntValue("number", 0)) + return err +} diff --git a/examples/git.go b/examples/git_test.go similarity index 70% rename from examples/git.go rename to examples/git_test.go index 43c114d..0a0c8e1 100644 --- a/examples/git.go +++ b/examples/git_test.go @@ -1,17 +1,18 @@ //+build examples -package git +package examples import ( "log" "os" "os/exec" + "testing" pipeline "github.com/ccremer/go-command-pipeline" "github.com/ccremer/go-command-pipeline/predicate" ) -func main() { +func TestExample_Git(t *testing.T) { p := pipeline.NewPipeline() p.WithSteps( predicate.ToStep("clone repository", CloneGitRepository(), predicate.Not(DirExists("my-repo"))), @@ -20,7 +21,7 @@ func main() { ) result := p.Run() if !result.IsSuccessful() { - log.Fatal(result.Err) + t.Fatal(result.Err) } } @@ -30,38 +31,32 @@ func logSuccess(result pipeline.Result) error { } func CloneGitRepository() pipeline.ActionFunc { - return func() pipeline.Result { + return func(_ pipeline.Context) pipeline.Result { err := execGitCommand("clone", "git@github.com/ccremer/go-command-pipeline") - if err != nil { - return pipeline.Result{Err: err} - } - return pipeline.Result{} + return pipeline.Result{Err: err} } } func Pull() pipeline.ActionFunc { - return func() pipeline.Result { + return func(_ pipeline.Context) pipeline.Result { err := execGitCommand("pull") - if err != nil { - return pipeline.Result{Err: err} - } - return pipeline.Result{} + return pipeline.Result{Err: err} } } func CheckoutBranch() pipeline.ActionFunc { - return func() pipeline.Result { + return func(_ pipeline.Context) pipeline.Result { err := execGitCommand("checkout", 
"master") - if err != nil { - return pipeline.Result{Err: err} - } - return pipeline.Result{} + return pipeline.Result{Err: err} } } func execGitCommand(args ...string) error { - cmd := exec.Command("git", args...) - return cmd.Run() + // replace 'echo' with actual 'git' binary + cmd := exec.Command("echo", args...) + cmd.Stdout = os.Stdout + err := cmd.Run() + return err } func DirExists(path string) predicate.Predicate { diff --git a/parallel/context.go b/parallel/context.go new file mode 100644 index 0000000..0af3443 --- /dev/null +++ b/parallel/context.go @@ -0,0 +1,62 @@ +package parallel + +import ( + "sync" + + pipeline "github.com/ccremer/go-command-pipeline" +) + +// ConcurrentContext implements pipeline.Context by wrapping an existing pipeline.Context in a sync.Mutex. +type ConcurrentContext struct { + WrappedContext pipeline.Context + m sync.Mutex +} + +// NewConcurrentContext wraps the given pipeline.Context in a new ConcurrentContext. +func NewConcurrentContext(wrapped pipeline.Context) ConcurrentContext { + return ConcurrentContext{ + WrappedContext: wrapped, + } +} + +// Value implements pipeline.Context:Value +func (ctx *ConcurrentContext) Value(key interface{}) interface{} { + ctx.m.Lock() + defer ctx.m.Unlock() + return ctx.WrappedContext.Value(key) +} + +// ValueOrDefault implements pipeline.Context:Value. +func (ctx *ConcurrentContext) ValueOrDefault(key interface{}, defaultValue interface{}) interface{} { + ctx.m.Lock() + defer ctx.m.Unlock() + return ctx.WrappedContext.ValueOrDefault(key, defaultValue) +} + +// StringValue implements pipeline.Context:StringValue. +func (ctx *ConcurrentContext) StringValue(key interface{}, defaultValue string) string { + ctx.m.Lock() + defer ctx.m.Unlock() + return ctx.WrappedContext.StringValue(key, defaultValue) +} + +// BoolValue implements pipeline.Context:BoolValue. +func (ctx *ConcurrentContext) BoolValue(key interface{}, defaultValue bool) bool { + ctx.m.Lock() + defer ctx.m.Unlock() + return ctx.WrappedContext.BoolValue(key, defaultValue) +} + +// IntValue implements pipeline.Context:IntValue. +func (ctx *ConcurrentContext) IntValue(key interface{}, defaultValue int) int { + ctx.m.Lock() + defer ctx.m.Unlock() + return ctx.WrappedContext.IntValue(key, defaultValue) +} + +// SetValue implements pipeline.Context:SetValue. 
+func (ctx *ConcurrentContext) SetValue(key interface{}, value interface{}) { + ctx.m.Lock() + defer ctx.m.Unlock() + ctx.WrappedContext.SetValue(key, value) +} diff --git a/parallel/context_test.go b/parallel/context_test.go new file mode 100644 index 0000000..de0e4ab --- /dev/null +++ b/parallel/context_test.go @@ -0,0 +1,19 @@ +package parallel + +import ( + "testing" + + pipeline "github.com/ccremer/go-command-pipeline" + "github.com/stretchr/testify/assert" +) + +func TestConcurrentContext_Implements_Context(t *testing.T) { + assert.Implements(t, (*pipeline.Context)(nil), new(ConcurrentContext)) +} + +func TestConcurrentContext_Constructor(t *testing.T) { + wrapped := pipeline.DefaultContext{} + ctx := NewConcurrentContext(&wrapped) + ctx.SetValue("key", "value") + assert.Equal(t, "value", ctx.StringValue("key", "default")) +} diff --git a/parallel/fanout.go b/parallel/fanout.go index 2825001..5695f3d 100644 --- a/parallel/fanout.go +++ b/parallel/fanout.go @@ -11,10 +11,11 @@ NewFanOutStep creates a pipeline step that runs nested pipelines in their own Go The function provided as PipelineSupplier is expected to close the given channel when no more pipelines should be executed, otherwise this step blocks forever. The step waits until all pipelines are finished. If the given ResultHandler is non-nil it will be called after all pipelines were run, otherwise the step is considered successful. +The given pipelines have to define their own pipeline.Context, it's not passed "down" from parent pipeline. */ func NewFanOutStep(name string, pipelineSupplier PipelineSupplier, handler ResultHandler) pipeline.Step { step := pipeline.Step{Name: name} - step.F = func() pipeline.Result { + step.F = func(_ pipeline.Context) pipeline.Result { pipelineChan := make(chan *pipeline.Pipeline) m := sync.Map{} var wg sync.WaitGroup diff --git a/parallel/fanout_test.go b/parallel/fanout_test.go index 9cae3a4..76988d9 100644 --- a/parallel/fanout_test.go +++ b/parallel/fanout_test.go @@ -49,13 +49,13 @@ func TestNewFanOutStep(t *testing.T) { step := NewFanOutStep("fanout", func(funcs chan *pipeline.Pipeline) { defer close(funcs) for i := 0; i < tt.jobs; i++ { - funcs <- pipeline.NewPipeline().WithSteps(pipeline.NewStep("step", func() pipeline.Result { + funcs <- pipeline.NewPipeline().WithSteps(pipeline.NewStep("step", func(_ pipeline.Context) pipeline.Result { atomic.AddUint64(&counts, 1) return pipeline.Result{Err: tt.returnErr} })) } }, handler) - result := step.F() + result := step.F(nil) assert.NoError(t, result.Err) assert.Equal(t, uint64(tt.expectedCounts), counts) }) @@ -69,7 +69,7 @@ func ExampleNewFanOutStep() { // create some pipelines for i := 0; i < 3; i++ { n := i - pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func() pipeline.Result { + pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func(_ pipeline.Context) pipeline.Result { time.Sleep(time.Duration(n * 10000000)) // fake some load fmt.Println(fmt.Sprintf("I am worker %d", n)) return pipeline.Result{} diff --git a/parallel/pool.go b/parallel/pool.go index b260d14..f6190dd 100644 --- a/parallel/pool.go +++ b/parallel/pool.go @@ -11,17 +11,18 @@ import ( NewWorkerPoolStep creates a pipeline step that runs nested pipelines in a thread pool. The function provided as PipelineSupplier is expected to close the given channel when no more pipelines should be executed, otherwise this step blocks forever. The step waits until all pipelines are finished. 
-If the given ResultHandler is non-nil it will be called after all pipelines were run, otherwise the step is considered successful. -The pipelines are executed in a pool of a number of Go routines indicated by size. -If size is 1, the pipelines are effectively run in sequence. -If size is 0 or less, the function panics. + * If the given ResultHandler is non-nil it will be called after all pipelines were run, otherwise the step is considered successful. + * The pipelines are executed in a pool of a number of Go routines indicated by size. + * If size is 1, the pipelines are effectively run in sequence. + * If size is 0 or less, the function panics. +The given pipelines have to define their own pipeline.Context, it's not passed "down" from parent pipeline. */ func NewWorkerPoolStep(name string, size int, pipelineSupplier PipelineSupplier, handler ResultHandler) pipeline.Step { if size < 1 { panic("pool size cannot be lower than 1") } step := pipeline.Step{Name: name} - step.F = func() pipeline.Result { + step.F = func(_ pipeline.Context) pipeline.Result { pipelineChan := make(chan *pipeline.Pipeline, size) m := sync.Map{} var wg sync.WaitGroup diff --git a/parallel/pool_test.go b/parallel/pool_test.go index 663e617..3f6d103 100644 --- a/parallel/pool_test.go +++ b/parallel/pool_test.go @@ -37,7 +37,7 @@ func TestNewWorkerPoolStep(t *testing.T) { } step := NewWorkerPoolStep("pool", 1, func(pipelines chan *pipeline.Pipeline) { defer close(pipelines) - pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep("step", func() pipeline.Result { + pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep("step", func(_ pipeline.Context) pipeline.Result { atomic.AddUint64(&counts, 1) return pipeline.Result{Err: tt.expectedError} })) @@ -45,7 +45,7 @@ func TestNewWorkerPoolStep(t *testing.T) { assert.Error(t, results[0].Err) return pipeline.Result{Err: results[0].Err} }) - result := step.F() + result := step.F(nil) assert.Error(t, result.Err) }) } @@ -58,7 +58,7 @@ func ExampleNewWorkerPoolStep() { // create some pipelines for i := 0; i < 3; i++ { n := i - pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func() pipeline.Result { + pipelines <- pipeline.NewPipeline().AddStep(pipeline.NewStep(fmt.Sprintf("i = %d", n), func(_ pipeline.Context) pipeline.Result { time.Sleep(time.Duration(n * 100000000)) // fake some load fmt.Println(fmt.Sprintf("This is job item %d", n)) return pipeline.Result{} diff --git a/pipeline.go b/pipeline.go index 4c7280e..c47c5dd 100644 --- a/pipeline.go +++ b/pipeline.go @@ -7,8 +7,9 @@ import ( type ( // Pipeline holds and runs intermediate actions, called "steps". Pipeline struct { - log Logger - steps []Step + log Logger + steps []Step + context Context } // Result is the object that is returned after each step and after running a pipeline. Result struct { @@ -27,7 +28,7 @@ type ( // This is required. F ActionFunc // H is the ResultHandler assigned to a pipeline Step. - // This is optional and it will be called in any case if it is set after F completed. + // This is optional, and it will be called in any case if it is set after F completed. // Use cases could be logging, updating a GUI or handle errors while continuing the pipeline. // The function may return nil even if the Result contains an error, in which case the pipeline will continue. // This function is called before the next step's F is invoked. @@ -40,7 +41,7 @@ type ( Log(message, name string) } // ActionFunc is the func that contains your business logic. 
- ActionFunc func() Result + ActionFunc func(ctx Context) Result // ResultHandler is a func that gets called when a step's ActionFunc has finished with any Result. ResultHandler func(result Result) error @@ -49,14 +50,19 @@ type ( func (n nullLogger) Log(_, _ string) {} -// NewPipeline returns a new Pipeline instance that doesn't log anything. +// NewPipeline returns a new quiet Pipeline instance with DefaultContext. func NewPipeline() *Pipeline { return NewPipelineWithLogger(nullLogger{}) } -// NewPipelineWithLogger returns a new Pipeline instance with the given logger that shouldn't be nil. +// NewPipelineWithContext returns a new Pipeline instance with the given context. +func NewPipelineWithContext(ctx Context) *Pipeline { + return &Pipeline{context: ctx, log: nullLogger{}} +} + +// NewPipelineWithLogger returns a new Pipeline instance with the given logger and DefaultContext. func NewPipelineWithLogger(logger Logger) *Pipeline { - return &Pipeline{log: logger} + return &Pipeline{log: logger, context: &DefaultContext{values: map[interface{}]interface{}{}}} } // AddStep appends the given step to the Pipeline at the end and returns itself. @@ -71,10 +77,10 @@ func (p *Pipeline) WithSteps(steps ...Step) *Pipeline { return p } -// WithNestedSteps is similar to AsNestedStep but it accepts the steps given directly as parameters. +// WithNestedSteps is similar to AsNestedStep, but it accepts the steps given directly as parameters. func (p *Pipeline) WithNestedSteps(name string, steps ...Step) Step { - return NewStep(name, func() Result { - nested := &Pipeline{log: p.log, steps: steps} + return NewStep(name, func(ctx Context) Result { + nested := &Pipeline{log: p.log, steps: steps, context: ctx} return nested.Run() }) } @@ -82,8 +88,8 @@ func (p *Pipeline) WithNestedSteps(name string, steps ...Step) Step { // AsNestedStep converts the Pipeline instance into a Step that can be used in other pipelines. // The logger and abort handler are passed to the nested pipeline. func (p *Pipeline) AsNestedStep(name string) Step { - return NewStep(name, func() Result { - nested := &Pipeline{log: p.log, steps: p.steps} + return NewStep(name, func(ctx Context) Result { + nested := &Pipeline{log: p.log, steps: p.steps, context: ctx} return nested.Run() }) } @@ -98,6 +104,12 @@ func (r Result) IsFailed() bool { return r.Err != nil } +// WithContext returns itself while setting the context for the pipeline steps. +func (p *Pipeline) WithContext(ctx Context) *Pipeline { + p.context = ctx + return p +} + // Run executes the pipeline and returns the result. // Steps are executed sequentially as they were added to the Pipeline. // If a Step returns a Result with a non-nil error, the Pipeline is aborted its Result contains the affected step's error. @@ -105,7 +117,7 @@ func (p *Pipeline) Run() Result { for _, step := range p.steps { p.log.Log("executing step", step.Name) - r := step.F() + r := step.F(p.context) if step.H != nil { if handlerErr := step.H(r); handlerErr != nil { return Result{Err: fmt.Errorf("step '%s' failed: %w", step.Name, handlerErr)} @@ -127,6 +139,14 @@ func NewStep(name string, action ActionFunc) Step { } } +// NewStepFromFunc returns a new Step with given name using a function that expects an error. +func NewStepFromFunc(name string, fn func(ctx Context) error) Step { + return NewStep(name, func(ctx Context) Result { + err := fn(ctx) + return Result{Err: err, Name: name} + }) +} + // WithResultHandler sets the ResultHandler of this specific step and returns the step itself. 
func (s Step) WithResultHandler(handler ResultHandler) Step { s.H = handler diff --git a/pipeline_test.go b/pipeline_test.go index 97951b2..0163663 100644 --- a/pipeline_test.go +++ b/pipeline_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestPipeline_runPipeline(t *testing.T) { +func TestPipeline_Run(t *testing.T) { callCount := 0 tests := map[string]struct { givenSteps []Step @@ -17,7 +17,7 @@ func TestPipeline_runPipeline(t *testing.T) { }{ "GivenSingleStep_WhenRunning_ThenCallStep": { givenSteps: []Step{ - NewStep("test-step", func() Result { + NewStep("test-step", func(_ Context) Result { callCount += 1 return Result{} }), @@ -26,7 +26,7 @@ func TestPipeline_runPipeline(t *testing.T) { }, "GivenSingleStepWithoutHandler_WhenRunningWithError_ThenReturnError": { givenSteps: []Step{ - NewStep("test-step", func() Result { + NewStep("test-step", func(_ Context) Result { callCount += 1 return Result{Err: errors.New("step failed")} }), @@ -36,14 +36,14 @@ func TestPipeline_runPipeline(t *testing.T) { }, "GivenSingleStepWithHandler_WhenRunningWithError_ThenAbortWithError": { givenSteps: []Step{ - NewStep("test-step", func() Result { + NewStep("test-step", func(_ Context) Result { callCount += 1 return Result{} }).WithResultHandler(func(result Result) error { callCount += 1 return errors.New("handler") }), - NewStep("don't run this step", func() Result { + NewStep("don't run this step", func(_ Context) Result { callCount += 1 return Result{} }), @@ -53,14 +53,14 @@ func TestPipeline_runPipeline(t *testing.T) { }, "GivenSingleStepWithHandler_WhenNullifyingError_ThenContinuePipeline": { givenSteps: []Step{ - NewStep("test-step", func() Result { + NewStep("test-step", func(_ Context) Result { callCount += 1 return Result{Err: errors.New("failed step")} }).WithResultHandler(func(result Result) error { callCount += 1 return nil }), - NewStep("continue", func() Result { + NewStep("continue", func(_ Context) Result { callCount += 1 return Result{} }), @@ -69,12 +69,12 @@ func TestPipeline_runPipeline(t *testing.T) { }, "GivenNestedPipeline_WhenParentPipelineRuns_ThenRunNestedAsWell": { givenSteps: []Step{ - NewStep("test-step", func() Result { + NewStep("test-step", func(_ Context) Result { callCount += 1 return Result{} }), NewPipeline(). - AddStep(NewStep("nested-step", func() Result { + AddStep(NewStep("nested-step", func(_ Context) Result { callCount += 1 return Result{} })).AsNestedStep("nested-pipeline"), @@ -85,7 +85,7 @@ func TestPipeline_runPipeline(t *testing.T) { givenSteps: []Step{ NewPipeline(). 
				WithNestedSteps("nested-pipeline",
-					NewStep("nested-step", func() Result {
+					NewStep("nested-step", func(_ Context) Result {
 						callCount += 1
 						return Result{}
 					})),
@@ -113,3 +113,25 @@
 		})
 	}
 }
+
+func TestPipeline_RunWithContext(t *testing.T) {
+	ctx := &DefaultContext{values: map[interface{}]interface{}{}}
+	p := NewPipelineWithContext(ctx)
+	p.AddStep(NewStep("context", func(ctx Context) Result {
+		ctx.SetValue("key", "value")
+		return Result{}
+	}))
+	result := p.Run()
+	require.NoError(t, result.Err)
+	assert.Equal(t, "value", ctx.StringValue("key", "default"))
+}
+
+func TestNewStepFromFunc(t *testing.T) {
+	called := false
+	step := NewStepFromFunc("name", func(ctx Context) error {
+		called = true
+		return nil
+	})
+	_ = step.F(nil)
+	assert.True(t, called)
+}
diff --git a/predicate.go b/predicate.go
deleted file mode 100644
index fb2071c..0000000
--- a/predicate.go
+++ /dev/null
@@ -1 +0,0 @@
-package pipeline
diff --git a/predicate/predicate.go b/predicate/predicate.go
index 5d3d436..3b146d0 100644
--- a/predicate/predicate.go
+++ b/predicate/predicate.go
@@ -6,18 +6,19 @@ import (
 
 type (
 	// Predicate is a function that expects 'true' if a pipeline.ActionFunc should run.
-	// Is is evaluated lazily resp. only when needed.
+	// It is evaluated lazily, i.e. only when needed.
 	Predicate func(step pipeline.Step) bool
 )
 
 // ToStep wraps the given action func in its own step.
 // When the step's function is called, the given Predicate will evaluate whether the action should actually run.
 // It returns the action's pipeline.Result, otherwise an empty (successful) pipeline.Result struct.
+// The pipeline.Context from the pipeline is passed through to the given action.
 func ToStep(name string, action pipeline.ActionFunc, predicate Predicate) pipeline.Step {
 	step := pipeline.Step{Name: name}
-	step.F = func() pipeline.Result {
+	step.F = func(ctx pipeline.Context) pipeline.Result {
 		if predicate(step) {
-			return action()
+			return action(ctx)
 		}
 		return pipeline.Result{}
 	}
@@ -27,9 +28,10 @@ func ToStep(name string, action pipeline.ActionFunc, predicate Predicate) pipeli
 // ToNestedStep wraps the given pipeline in its own step.
 // When the step's function is called, the given Predicate will evaluate whether the nested pipeline.Pipeline should actually run.
 // It returns the pipeline's pipeline.Result, otherwise an empty (successful) pipeline.Result struct.
+// The given pipeline has to define its own pipeline.Context; it's not passed "down".
 func ToNestedStep(name string, p *pipeline.Pipeline, predicate Predicate) pipeline.Step {
 	step := pipeline.Step{Name: name}
-	step.F = func() pipeline.Result {
+	step.F = func(_ pipeline.Context) pipeline.Result {
 		if predicate(step) {
 			return p.Run()
 		}
@@ -39,11 +41,12 @@ func ToNestedStep(name string, p *pipeline.Pipeline, predicate Predicate) pipeli
 }
 
 // WrapIn returns a new step that wraps the given step and executes its action only if the given Predicate evaluates true.
+// The pipeline.Context from the pipeline is passed through to the given action.
func WrapIn(originalStep pipeline.Step, predicate Predicate) pipeline.Step { wrappedStep := pipeline.Step{Name: originalStep.Name} - wrappedStep.F = func() pipeline.Result { + wrappedStep.F = func(ctx pipeline.Context) pipeline.Result { if predicate(wrappedStep) { - return originalStep.F() + return originalStep.F(ctx) } return pipeline.Result{} } diff --git a/predicate/predicate_test.go b/predicate/predicate_test.go index be4d0ce..231b82f 100644 --- a/predicate/predicate_test.go +++ b/predicate/predicate_test.go @@ -58,11 +58,11 @@ func Test_Predicates(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { counter = 0 - step := ToStep("name", func() pipeline.Result { + step := ToStep("name", func(_ pipeline.Context) pipeline.Result { counter += 1 return pipeline.Result{} }, tt.givenPredicate) - result := step.F() + result := step.F(nil) assert.Equal(t, tt.expectedCounts, counter) assert.NoError(t, result.Err) }) @@ -88,12 +88,12 @@ func TestToNestedStep(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { counter = 0 - p := pipeline.NewPipeline().AddStep(pipeline.NewStep("nested step", func() pipeline.Result { + p := pipeline.NewPipeline().AddStep(pipeline.NewStep("nested step", func(_ pipeline.Context) pipeline.Result { counter++ return pipeline.Result{} })) step := ToNestedStep("super step", p, tt.givenPredicate) - _ = step.F() + _ = step.F(nil) assert.Equal(t, tt.expectedCounts, counter) }) } @@ -118,12 +118,12 @@ func TestWrapIn(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { counter = 0 - step := pipeline.NewStep("step", func() pipeline.Result { + step := pipeline.NewStep("step", func(_ pipeline.Context) pipeline.Result { counter++ return pipeline.Result{} }) wrapped := WrapIn(step, tt.givenPredicate) - result := wrapped.F() + result := wrapped.F(nil) require.NoError(t, result.Err) assert.Equal(t, tt.expectedCalls, counter) assert.Equal(t, step.Name, wrapped.Name)