From d0821c232fc95f7198bbe7edeeea7529866f88ce Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 12 Apr 2022 12:50:05 +0200 Subject: [PATCH] Add wrapper to store and retrieve values in context --- context.go | 47 ++++++++++++++++++++++++++++++++ context_test.go | 72 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 context.go create mode 100644 context_test.go diff --git a/context.go b/context.go new file mode 100644 index 0000000..9e73fc3 --- /dev/null +++ b/context.go @@ -0,0 +1,47 @@ +package pipeline + +import ( + "context" + "errors" + "sync" +) + +type contextKey struct{} + +// VariableContext adds a map to the given context that can be used to store intermediate values in the context. +// It uses sync.Map under the hood. +// +// See also AddToContext() and ValueFromContext. +func VariableContext(parent context.Context) context.Context { + return context.WithValue(parent, contextKey{}, &sync.Map{}) +} + +// AddToContext adds the given key and value to ctx. +// Any keys or values added during pipeline execution is available in the next steps, provided the pipeline runs synchronously. +// In parallel executed pipelines you may encounter race conditions. +// Use ValueFromContext to retrieve values. +// +// Note: This method is thread-safe, but panics if ctx has not been set up with VariableContext first. +func AddToContext(ctx context.Context, key, value interface{}) { + m := ctx.Value(contextKey{}) + if m == nil { + panic(errors.New("context was not set up with VariableContext()")) + } + m.(*sync.Map).Store(key, value) +} + +// ValueFromContext returns the value from the given context with the given key. +// It returns the value and true, or nil and false if the key doesn't exist. +// It may return nil and true if the key exists, but the value actually is nil. +// Use AddToContext to store values. +// +// Note: This method is thread-safe, but panics if the ctx has not been set up with VariableContext first. +func ValueFromContext(ctx context.Context, key interface{}) (interface{}, bool) { + m := ctx.Value(contextKey{}) + if m == nil { + panic(errors.New("context was not set up with VariableContext()")) + } + mp := m.(*sync.Map) + val, found := mp.Load(key) + return val, found +} diff --git a/context_test.go b/context_test.go new file mode 100644 index 0000000..126a4c9 --- /dev/null +++ b/context_test.go @@ -0,0 +1,72 @@ +package pipeline + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContext(t *testing.T) { + tests := map[string]struct { + givenKey interface{} + givenValue interface{} + expectedValue interface{} + expectedFound bool + }{ + "GivenNonExistentKey_ThenExpectNilAndFalse": { + givenKey: nil, + expectedValue: nil, + }, + "GivenKeyWithNilValue_ThenExpectNilAndTrue": { + givenKey: "key", + givenValue: nil, + expectedValue: nil, + expectedFound: true, + }, + "GivenKeyWithValue_ThenExpectValueAndTrue": { + givenKey: "key", + givenValue: "value", + expectedValue: "value", + expectedFound: true, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ctx := VariableContext(context.Background()) + if tc.givenKey != nil { + AddToContext(ctx, tc.givenKey, tc.givenValue) + } + result, found := ValueFromContext(ctx, tc.givenKey) + assert.Equal(t, tc.expectedValue, result, "value") + assert.Equal(t, tc.expectedFound, found, "value found") + }) + } +} + +func TestContextPanics(t *testing.T) { + assert.PanicsWithError(t, "context was not set up with VariableContext()", func() { + AddToContext(context.Background(), "key", "value") + }, "AddToContext") + assert.PanicsWithError(t, "context was not set up with VariableContext()", func() { + ValueFromContext(context.Background(), "key") + }, "ValueFromContext") +} + +func ExampleVariableContext() { + ctx := VariableContext(context.Background()) + p := NewPipeline().WithSteps( + NewStepFromFunc("store value", func(ctx context.Context) error { + AddToContext(ctx, "key", "value") + return nil + }), + NewStepFromFunc("retrieve value", func(ctx context.Context) error { + value, _ := ValueFromContext(ctx, "key") + fmt.Println(value) + return nil + }), + ) + p.RunWithContext(ctx) + // Output: value +}