Skip to content

Commit

Permalink
[Bug] correctly set task execution phase for terminal array node (#5136)
Browse files Browse the repository at this point in the history
* correctly set task execution phase for array node

Signed-off-by: Paul Dittamo <[email protected]>

* clean up

Signed-off-by: Paul Dittamo <[email protected]>

* update unit tests

Signed-off-by: Paul Dittamo <[email protected]>

---------

Signed-off-by: Paul Dittamo <[email protected]>
  • Loading branch information
pvditt authored Jan 13, 2025
1 parent 6ea9531 commit be66530
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 21 deletions.
23 changes: 18 additions & 5 deletions flytepropeller/pkg/controller/nodes/array/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut

eventRecorder := newArrayEventRecorder(nCtx.EventsRecorder())
messageCollector := errorcollector.NewErrorMessageCollector()

taskPhase := idlcore.TaskExecution_ABORTED
if arrayNodeState.Phase == v1alpha1.ArrayNodePhaseFailing {
taskPhase = idlcore.TaskExecution_FAILED
}
switch arrayNodeState.Phase {
case v1alpha1.ArrayNodePhaseExecuting, v1alpha1.ArrayNodePhaseFailing:
for i, nodePhaseUint64 := range arrayNodeState.SubNodePhases.GetItems() {
Expand Down Expand Up @@ -122,13 +127,12 @@ func (a *arrayNodeHandler) Abort(ctx context.Context, nCtx interfaces.NodeExecut
}

// update state for subNodes
if err := eventRecorder.finalize(ctx, nCtx, idlcore.TaskExecution_ABORTED, 0, a.eventConfig); err != nil {
if err := eventRecorder.finalize(ctx, nCtx, taskPhase, 0, a.eventConfig); err != nil {
// a task event with abort phase is already emitted when handling ArrayNodePhaseFailing
if eventsErr.IsAlreadyExists(err) {
return nil
if !eventsErr.IsAlreadyExists(err) {
logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error())
return err
}
logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error())
return err
}

return nil
Expand Down Expand Up @@ -462,6 +466,7 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
nCtx.ExecutionContext().IncrementParallelism()
}
case v1alpha1.ArrayNodePhaseFailing:
// note: sub node eventing handled during Abort
if err := a.Abort(ctx, nCtx, "ArrayNodeFailing"); err != nil {
return handler.UnknownTransition, err
}
Expand Down Expand Up @@ -609,6 +614,14 @@ func (a *arrayNodeHandler) Handle(ctx context.Context, nCtx interfaces.NodeExecu
return handler.UnknownTransition, err
}

// ensure task_execution set to succeeded
if err := eventRecorder.finalize(ctx, nCtx, idlcore.TaskExecution_SUCCEEDED, 0, a.eventConfig); err != nil {
if !eventsErr.IsAlreadyExists(err) {
logger.Errorf(ctx, "ArrayNode event recording failed: [%s]", err.Error())
return handler.UnknownTransition, err
}
}

return handler.DoTransition(handler.TransitionTypeEphemeral, handler.PhaseInfoSuccess(
&handler.ExecutionInfo{
OutputInfo: &handler.OutputInfo{
Expand Down
55 changes: 39 additions & 16 deletions flytepropeller/pkg/controller/nodes/array/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,40 +202,56 @@ func createNodeExecutionContext(dataStore *storage.DataStore, eventRecorder inte

func TestAbort(t *testing.T) {
ctx := context.Background()
scope := promutils.NewTestScope()
dataStore, err := storage.NewDataStore(&storage.Config{
Type: storage.TypeMemory,
}, scope)
assert.NoError(t, err)

nodeHandler := &mocks.NodeHandler{}
nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)

// initialize ArrayNodeHandler
arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope)
assert.NoError(t, err)

tests := []struct {
name string
inputMap map[string][]int64
subNodePhases []v1alpha1.NodePhase
subNodeTaskPhases []core.Phase
expectedExternalResourcePhases []idlcore.TaskExecution_Phase
arrayNodeState v1alpha1.ArrayNodePhase
expectedTaskExecutionPhase idlcore.TaskExecution_Phase
}{
{
name: "Success",
name: "Aborted after failed",
inputMap: map[string][]int64{
"foo": []int64{0, 1, 2},
},
subNodePhases: []v1alpha1.NodePhase{v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted},
subNodeTaskPhases: []core.Phase{core.PhaseSuccess, core.PhaseRunning, core.PhaseUndefined},
expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_ABORTED},
arrayNodeState: v1alpha1.ArrayNodePhaseFailing,
expectedTaskExecutionPhase: idlcore.TaskExecution_FAILED,
},
{
name: "Aborted while running",
inputMap: map[string][]int64{
"foo": []int64{0, 1, 2},
},
subNodePhases: []v1alpha1.NodePhase{v1alpha1.NodePhaseSucceeded, v1alpha1.NodePhaseRunning, v1alpha1.NodePhaseNotYetStarted},
subNodeTaskPhases: []core.Phase{core.PhaseSuccess, core.PhaseRunning, core.PhaseUndefined},
expectedExternalResourcePhases: []idlcore.TaskExecution_Phase{idlcore.TaskExecution_ABORTED},
arrayNodeState: v1alpha1.ArrayNodePhaseExecuting,
expectedTaskExecutionPhase: idlcore.TaskExecution_ABORTED,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scope := promutils.NewTestScope()
dataStore, err := storage.NewDataStore(&storage.Config{
Type: storage.TypeMemory,
}, scope)
assert.NoError(t, err)

nodeHandler := &mocks.NodeHandler{}
nodeHandler.OnAbortMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)
nodeHandler.OnFinalizeMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)

// initialize ArrayNodeHandler
arrayNodeHandler, err := createArrayNodeHandler(ctx, t, nodeHandler, dataStore, scope)
assert.NoError(t, err)

// initialize universal variables
literalMap := convertMapToArrayLiterals(test.inputMap)

Expand All @@ -250,7 +266,7 @@ func TestAbort(t *testing.T) {

// initialize ArrayNodeState
arrayNodeState := &handler.ArrayNodeState{
Phase: v1alpha1.ArrayNodePhaseFailing,
Phase: test.arrayNodeState,
}
for _, item := range []struct {
arrayReference *bitarray.CompactArray
Expand Down Expand Up @@ -279,12 +295,13 @@ func TestAbort(t *testing.T) {
nCtx := createNodeExecutionContext(dataStore, eventRecorder, nil, literalMap, &arrayNodeSpec, arrayNodeState, 0, workflowMaxParallelism)

// evaluate node
err := arrayNodeHandler.Abort(ctx, nCtx, "foo")
err = arrayNodeHandler.Abort(ctx, nCtx, "foo")
assert.NoError(t, err)

nodeHandler.AssertNumberOfCalls(t, "Abort", len(test.expectedExternalResourcePhases))
if len(test.expectedExternalResourcePhases) > 0 {
assert.Equal(t, 1, len(eventRecorder.taskExecutionEvents))
assert.Equal(t, test.expectedTaskExecutionPhase, eventRecorder.taskExecutionEvents[0].GetPhase())

externalResources := eventRecorder.taskExecutionEvents[0].GetMetadata().GetExternalResources()
assert.Equal(t, len(test.expectedExternalResourcePhases), len(externalResources))
Expand Down Expand Up @@ -1296,6 +1313,9 @@ func TestHandleArrayNodePhaseSucceeding(t *testing.T) {
assert.Equal(t, int64(*outputValue), collection.GetLiterals()[i].GetScalar().GetPrimitive().GetInteger())
}
}

assert.Equal(t, 1, len(eventRecorder.taskExecutionEvents))
assert.Equal(t, idlcore.TaskExecution_SUCCEEDED, eventRecorder.taskExecutionEvents[0].GetPhase())
})
}
}
Expand Down Expand Up @@ -1374,6 +1394,9 @@ func TestHandleArrayNodePhaseFailing(t *testing.T) {
assert.Equal(t, test.expectedArrayNodePhase, arrayNodeState.Phase)
assert.Equal(t, test.expectedTransitionPhase, transition.Info().GetPhase())
nodeHandler.AssertNumberOfCalls(t, "Abort", test.expectedAbortCalls)

assert.Equal(t, 1, len(eventRecorder.taskExecutionEvents))
assert.Equal(t, idlcore.TaskExecution_FAILED, eventRecorder.taskExecutionEvents[0].GetPhase())
})
}
}
Expand Down

0 comments on commit be66530

Please sign in to comment.