Skip to content

Commit

Permalink
parameterize
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanuelAaron committed Aug 28, 2024
1 parent f866ed1 commit 7691de3
Showing 1 changed file with 47 additions and 47 deletions.
94 changes: 47 additions & 47 deletions master/internal/api_runs_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,36 +193,12 @@ func TestSearchRunsSort(t *testing.T) {
HParams: hyperparameters2,
}, task2.TaskID))

// Sort by start time
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("startTime=asc"),
})

require.NoError(t, err)
require.Equal(t, int32(exp.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp2.ID), resp.Runs[1].Experiment.Id)

// Sort by hyperparameter
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("hp.global_batch_size=desc"),
})

require.NoError(t, err)
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)

// Sort by nested hyperparameter
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("hp.test1.test2=desc"),
})

// Get runs in project
resp, err = api.SearchRuns(ctx, req)
require.NoError(t, err)
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)
require.Len(t, resp.Runs, 2)

// add metadata
rawMetadata := map[string]any{
"number_key": 1,
"nested": map[string]any{
Expand All @@ -231,7 +207,7 @@ func TestSearchRunsSort(t *testing.T) {
}
metadata := newProtoStruct(t, rawMetadata)
_, err = api.PostRunMetadata(ctx, &apiv1.PostRunMetadataRequest{
RunId: resp.Runs[1].Id,
RunId: resp.Runs[0].Id,
Metadata: metadata,
})
require.NoError(t, err)
Expand All @@ -244,30 +220,54 @@ func TestSearchRunsSort(t *testing.T) {
}
metadata = newProtoStruct(t, rawMetadata)
_, err = api.PostRunMetadata(ctx, &apiv1.PostRunMetadataRequest{
RunId: resp.Runs[0].Id,
RunId: resp.Runs[1].Id,
Metadata: metadata,
})
require.NoError(t, err)

// Sort by custom metadata
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("metadata.number_key=desc"),
})

require.NoError(t, err)
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)
tests := map[string]struct {
sortBy string
reverse bool
}{
"StartTime": {
sortBy: "startTime=asc",
reverse: false,
},
"Hyperparameter": {
sortBy: "hp.global_batch_size=desc",
reverse: true,
},
"HyperparameterNested": {
sortBy: "hp.test1.test2=desc",
reverse: true,
},
"Metadata": {
sortBy: "metadata.number_key=desc",
reverse: true,
},
"MetadataNested": {
sortBy: "metadata.nested.number_key=desc",
reverse: true,
},
}

// Sort by nested custom metadata
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: req.ProjectId,
Sort: ptrs.Ptr("metadata.nested.number_key=desc"),
})
for testCase, testVars := range tests {
t.Run(testCase, func(t *testing.T) {
resp, err = api.SearchRuns(ctx, &apiv1.SearchRunsRequest{
ProjectId: &projectID,
Sort: ptrs.Ptr(testVars.sortBy),
})

require.NoError(t, err)
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)
require.NoError(t, err)
if testVars.reverse {
require.Equal(t, int32(exp2.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp.ID), resp.Runs[1].Experiment.Id)
} else {
require.Equal(t, int32(exp.ID), resp.Runs[0].Experiment.Id)
require.Equal(t, int32(exp2.ID), resp.Runs[1].Experiment.Id)
}
})
}
}

func TestSearchRunsFilter(t *testing.T) {
Expand Down

0 comments on commit 7691de3

Please sign in to comment.