From 1307b786d9b13696b6fc8eae65654488a7621fd6 Mon Sep 17 00:00:00 2001 From: featherchen Date: Tue, 22 Oct 2024 14:14:27 -0700 Subject: [PATCH 1/5] feat(task_repo): set default version to the latest Signed-off-by: featherchen --- .../pkg/manager/impl/validation/validation.go | 20 +++++++++++++++++++ .../pkg/repositories/gormimpl/task_repo.go | 12 ++++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/validation/validation.go b/flyteadmin/pkg/manager/impl/validation/validation.go index de2927495c..2ff5859b44 100644 --- a/flyteadmin/pkg/manager/impl/validation/validation.go +++ b/flyteadmin/pkg/manager/impl/validation/validation.go @@ -95,6 +95,23 @@ func ValidateIdentifierFieldsSet(id *core.Identifier) error { return nil } +// ValidateTaskIdentifierFieldsSet Validates that all required fields, except version, for a task identifier are present. +func ValidateTaskIdentifierFieldsSet(id *core.Identifier) error { + if id == nil { + return shared.GetMissingArgumentError(shared.ID) + } + if err := ValidateEmptyStringField(id.Project, shared.Project); err != nil { + return err + } + if err := ValidateEmptyStringField(id.Domain, shared.Domain); err != nil { + return err + } + if err := ValidateEmptyStringField(id.Name, shared.Name); err != nil { + return err + } + return nil +} + // ValidateIdentifier Validates that all required fields for an identifier are present. func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error { if id == nil { @@ -105,6 +122,9 @@ func ValidateIdentifier(id *core.Identifier, expectedType common.Entity) error { "unexpected resource type %s for identifier [%+v], expected %s instead", strings.ToLower(id.ResourceType.String()), id, strings.ToLower(entityToResourceType[expectedType].String())) } + if id.ResourceType == core.ResourceType_TASK { + return ValidateTaskIdentifierFieldsSet(id) + } return ValidateIdentifierFieldsSet(id) } diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 1b42756b7a..784bb306b7 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -49,14 +49,20 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { var task models.Task timer := r.metrics.GetDuration.Start() - tx := r.db.WithContext(ctx).Where(&models.Task{ + query := r.db.WithContext(ctx).Where(&models.Task{ TaskKey: models.TaskKey{ Project: input.Project, Domain: input.Domain, Name: input.Name, - Version: input.Version, }, - }).Take(&task) + }) + + if input.Version == "" { + query = query.Order("version DESC").Limit(1) + } else { + query = query.Where("version = ?", input.Version) + } + tx := query.Take(&task) timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{ From 7770c6a686b388b8a9cdeeb5d8bab5b8e5fb4a87 Mon Sep 17 00:00:00 2001 From: featherchen Date: Tue, 22 Oct 2024 15:08:09 -0700 Subject: [PATCH 2/5] fix(test): remove tests related to empty task version Signed-off-by: featherchen --- .../task_execution_validator_test.go | 22 ----------------- .../impl/validation/task_validator_test.go | 8 ------- .../pkg/repositories/gormimpl/task_repo.go | 24 ++++++++++--------- 3 files changed, 13 insertions(+), 41 deletions(-) diff --git a/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go index 26f2c48948..712d5948b0 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_execution_validator_test.go @@ -63,28 +63,6 @@ func TestValidateTaskExecutionRequest_MissingFields(t *testing.T) { }, maxOutputSizeInBytes) assert.EqualError(t, err, "missing occurred_at") - err = ValidateTaskExecutionRequest(&admin.TaskExecutionEventRequest{ - Event: &event.TaskExecutionEvent{ - OccurredAt: taskEventOccurredAtProto, - TaskId: &core.Identifier{ - ResourceType: core.ResourceType_TASK, - Project: "project", - Domain: "domain", - Name: "name", - }, - ParentNodeExecutionId: &core.NodeExecutionIdentifier{ - NodeId: "nodey", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - }, - }, - RetryAttempt: 0, - }, - }, maxOutputSizeInBytes) - assert.EqualError(t, err, "missing version") - err = ValidateTaskExecutionRequest(&admin.TaskExecutionEventRequest{ Event: &event.TaskExecutionEvent{ OccurredAt: taskEventOccurredAtProto, diff --git a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go index be450e4171..d6bc780bbb 100644 --- a/flyteadmin/pkg/manager/impl/validation/task_validator_test.go +++ b/flyteadmin/pkg/manager/impl/validation/task_validator_test.go @@ -110,14 +110,6 @@ func TestValidateTaskEmptyName(t *testing.T) { assert.EqualError(t, err, "missing name") } -func TestValidateTaskEmptyVersion(t *testing.T) { - request := testutils.GetValidTaskRequest() - request.Id.Version = "" - err := ValidateTask(context.Background(), request, testutils.GetRepoWithDefaultProject(), - getMockTaskResources(), mockWhitelistConfigProvider, taskApplicationConfigProvider) - assert.EqualError(t, err, "missing version") -} - func TestValidateTaskEmptyType(t *testing.T) { request := testutils.GetValidTaskRequest() request.Spec.Template.Type = "" diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 784bb306b7..828888e646 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -49,20 +49,22 @@ func (r *TaskRepo) Create(ctx context.Context, input models.Task, descriptionEnt func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models.Task, error) { var task models.Task timer := r.metrics.GetDuration.Start() - query := r.db.WithContext(ctx).Where(&models.Task{ - TaskKey: models.TaskKey{ - Project: input.Project, - Domain: input.Domain, - Name: input.Name, - }, - }) - + var tx *gorm.DB if input.Version == "" { - query = query.Order("version DESC").Limit(1) + tx := r.db.WithContext(ctx).Limit(1) + tx = tx.Order("DESC") + tx.Find(&task) } else { - query = query.Where("version = ?", input.Version) + tx = r.db.WithContext(ctx).Where(&models.Task{ + TaskKey: models.TaskKey{ + Project: input.Project, + Domain: input.Domain, + Name: input.Name, + Version: input.Version, + }, + }).Take(&task) } - tx := query.Take(&task) + timer.Stop() if errors.Is(tx.Error, gorm.ErrRecordNotFound) { return models.Task{}, flyteAdminDbErrors.GetMissingEntityError(core.ResourceType_TASK.String(), &core.Identifier{ From 9c88b3e3fb77e761d5e6e93dae8f5803824e5bcb Mon Sep 17 00:00:00 2001 From: featherchen Date: Wed, 23 Oct 2024 02:40:22 -0700 Subject: [PATCH 3/5] feat(test): TestGetTask Signed-off-by: featherchen --- .../pkg/repositories/gormimpl/task_repo.go | 4 ++-- .../repositories/gormimpl/task_repo_test.go | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index 828888e646..c8c9a6948f 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -51,8 +51,8 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models timer := r.metrics.GetDuration.Start() var tx *gorm.DB if input.Version == "" { - tx := r.db.WithContext(ctx).Limit(1) - tx = tx.Order("DESC") + tx := r.db.WithContext(ctx).Where("project = ? AND domain = ? AND name = ?", input.Project, input.Domain, input.Name).Limit(1) + tx = tx.Order("version DESC") tx.Find(&task) } else { tx = r.db.WithContext(ctx).Where(&models.Task{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 3309ad3609..5204c2a1cf 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -79,6 +79,28 @@ func TestGetTask(t *testing.T) { assert.Equal(t, version, output.Version) assert.Equal(t, []byte{1, 2}, output.Closure) assert.Equal(t, pythonTestTaskType, output.Type) + + //When version is empty, return the latest task + GlobalMock = mocket.Catcher.Reset() + GlobalMock.Logging = true + + GlobalMock.NewMock().WithQuery( + `SELECT * FROM "tasks" WHERE project = $1 AND domain = $2 AND name = $3 ORDER BY version DESC LIMIT 1`). + WithReply(tasks) + output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ + Project: project, + Domain: domain, + Name: name, + Version: "", + }) + + assert.NoError(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, domain, output.Domain) + assert.Equal(t, name, output.Name) + assert.Equal(t, "v2", output.Version) + assert.Equal(t, []byte{3, 4}, output.Closure) + assert.Equal(t, pythonTestTaskType, output.Type) } func TestListTasks(t *testing.T) { From 5341abfab5afa5650654f19538bd401c9e7fd97c Mon Sep 17 00:00:00 2001 From: featherchen Date: Thu, 24 Oct 2024 21:54:14 -0700 Subject: [PATCH 4/5] fix: TestGetTask Signed-off-by: featherchen --- flyteadmin/pkg/repositories/gormimpl/task_repo.go | 4 ++-- flyteadmin/pkg/repositories/gormimpl/task_repo_test.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo.go b/flyteadmin/pkg/repositories/gormimpl/task_repo.go index c8c9a6948f..82cbd31f40 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo.go @@ -51,8 +51,8 @@ func (r *TaskRepo) Get(ctx context.Context, input interfaces.Identifier) (models timer := r.metrics.GetDuration.Start() var tx *gorm.DB if input.Version == "" { - tx := r.db.WithContext(ctx).Where("project = ? AND domain = ? AND name = ?", input.Project, input.Domain, input.Name).Limit(1) - tx = tx.Order("version DESC") + tx = r.db.WithContext(ctx).Where(`"tasks"."project" = ? AND "tasks"."domain" = ? AND "tasks"."name" = ?`, input.Project, input.Domain, input.Name).Limit(1) + tx = tx.Order(`"tasks"."version" DESC`) tx.Find(&task) } else { tx = r.db.WithContext(ctx).Where(&models.Task{ diff --git a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go index 5204c2a1cf..643dc66e4d 100644 --- a/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go +++ b/flyteadmin/pkg/repositories/gormimpl/task_repo_test.go @@ -85,7 +85,7 @@ func TestGetTask(t *testing.T) { GlobalMock.Logging = true GlobalMock.NewMock().WithQuery( - `SELECT * FROM "tasks" WHERE project = $1 AND domain = $2 AND name = $3 ORDER BY version DESC LIMIT 1`). + `SELECT * FROM "tasks" WHERE "tasks"."project" = $1 AND "tasks"."domain" = $2 AND "tasks"."name" = $3 ORDER BY "tasks"."version" DESC LIMIT 1`). WithReply(tasks) output, err = taskRepo.Get(context.Background(), interfaces.Identifier{ Project: project, @@ -98,8 +98,8 @@ func TestGetTask(t *testing.T) { assert.Equal(t, project, output.Project) assert.Equal(t, domain, output.Domain) assert.Equal(t, name, output.Name) - assert.Equal(t, "v2", output.Version) - assert.Equal(t, []byte{3, 4}, output.Closure) + assert.Equal(t, version, output.Version) + assert.Equal(t, []byte{1, 2}, output.Closure) assert.Equal(t, pythonTestTaskType, output.Type) } From fb6f97d321b2bd951fbee8ad598d226f096a04ad Mon Sep 17 00:00:00 2001 From: featherchen Date: Thu, 7 Nov 2024 19:07:37 -0800 Subject: [PATCH 5/5] feat(test): Get integration test Signed-off-by: featherchen --- flyteadmin/tests/task_execution_test.go | 31 +++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/flyteadmin/tests/task_execution_test.go b/flyteadmin/tests/task_execution_test.go index e380104684..c333bd7fd5 100644 --- a/flyteadmin/tests/task_execution_test.go +++ b/flyteadmin/tests/task_execution_test.go @@ -45,6 +45,37 @@ var taskExecutionIdentifier = &core.TaskExecutionIdentifier{ RetryAttempt: 1, } +func TestGetTaskExecutions(t *testing.T) { + truncateAllTablesForTestingOnly() + populateWorkflowExecutionsForTestingOnly() + + ctx := context.Background() + client, conn := GetTestAdminServiceClient() + defer conn.Close() + + _, err := client.CreateTask(ctx, &admin.TaskCreateRequest{ + Id: taskIdentifier, + Spec: testutils.GetValidTaskRequest().Spec, + }) + require.NoError(t, err) + + resp, err := client.GetTask(ctx, &admin.ObjectGetRequest{ + Id: &core.Identifier{ + ResourceType: core.ResourceType_TASK, + Project: project, + Domain: "development", + Name: "task name", + }, + }) + + assert.Nil(t, err) + assert.Equal(t, resp.Id.Project, project) + assert.Equal(t, resp.Id.Domain, "development") + assert.Equal(t, resp.Id.Name, "task name") + assert.Equal(t, resp.Id.Version, "task version") + +} + func createTaskAndNodeExecution( ctx context.Context, t *testing.T, client service.AdminServiceClient, conn *grpc.ClientConn, occurredAtProto *timestamp.Timestamp) {