From 55ed7cc035a0bfac69b18380a4ef8ca619e2f422 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 8 Aug 2024 13:09:26 -0400 Subject: [PATCH 01/28] single get trial call --- master/internal/api_trials.go | 41 ++++++++++++++++---------- master/static/srv/get_trials_basic.sql | 10 +++++++ 2 files changed, 35 insertions(+), 16 deletions(-) create mode 100644 master/static/srv/get_trials_basic.sql diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index e62d548ecd4..4c6680c0123 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -902,27 +902,36 @@ func (a *apiServer) CompareTrials(ctx context.Context, req *apiv1.CompareTrialsRequest, ) (*apiv1.CompareTrialsResponse, error) { trialsList := make([]*apiv1.ComparableTrial, 0, len(req.TrialIds)) - for _, trialID := range req.TrialIds { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), - experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - return nil, err - } + //nolint:staticcheck // SA1019: backward compatibility + metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) + if err != nil { + return nil, err + } - container := &apiv1.ComparableTrial{Trial: &trialv1.Trial{}} - switch err := a.m.db.QueryProto("get_trial_basic", container.Trial, trialID); { - case err == db.ErrNotFound: - return nil, status.Errorf(codes.NotFound, "trial %d not found:", trialID) - case err != nil: - return nil, errors.Wrapf(err, "failed to get trial %d", trialID) - } + trialsObjList := []*trialv1.Trial{} - //nolint:staticcheck // SA1019: backward compatibility - metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) - if err != nil { + trialIds := make([]string, 0, len(req.TrialIds)) + for _, userID := range req.TrialIds { + trialIds = append(trialIds, strconv.Itoa(int(userID))) + } + trialIDFilterExpr := strings.Join(trialIds, ",") + + switch err := a.m.db.QueryProto("get_trials_basic", &trialsObjList, trialIDFilterExpr); { + case err == db.ErrNotFound: + return nil, status.Errorf(codes.NotFound, "trial not found") + case err != nil: + return nil, errors.Wrapf(err, "failed to get trials") + } + + for _, trialObj := range trialsObjList { + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialObj.Id), + experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } - tsample, err := a.multiTrialSample(trialID, req.MetricNames, metricGroup, + container := &apiv1.ComparableTrial{Trial: trialObj} + + tsample, err := a.multiTrialSample(trialObj.Id, req.MetricNames, metricGroup, int(req.MaxDatapoints), int(req.StartBatches), int(req.EndBatches), req.TimeSeriesFilter, req.MetricIds) if err != nil { diff --git a/master/static/srv/get_trials_basic.sql b/master/static/srv/get_trials_basic.sql new file mode 100644 index 00000000000..f26f1b401cf --- /dev/null +++ b/master/static/srv/get_trials_basic.sql @@ -0,0 +1,10 @@ +SELECT + t.id, + t.experiment_id, + 'STATE_' || t.state AS state, + t.start_time, + t.end_time, + t.hparams, + t.runner_state +FROM trials t +WHERE t.id IN (SELECT unnest(string_to_array($1, ',')::int [])) From ee190a17240374688efd598b64d87df18c4a036a Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Mon, 12 Aug 2024 11:25:32 -0400 Subject: [PATCH 02/28] fix permission check --- master/internal/api_trials.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index 4c6680c0123..0293bb2b7ec 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -902,6 +902,12 @@ func (a *apiServer) CompareTrials(ctx context.Context, req *apiv1.CompareTrialsRequest, ) (*apiv1.CompareTrialsResponse, error) { trialsList := make([]*apiv1.ComparableTrial, 0, len(req.TrialIds)) + for _, trialID := range req.TrialIds { + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), + experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { + return nil, err + } + } //nolint:staticcheck // SA1019: backward compatibility metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) if err != nil { @@ -911,24 +917,17 @@ func (a *apiServer) CompareTrials(ctx context.Context, trialsObjList := []*trialv1.Trial{} trialIds := make([]string, 0, len(req.TrialIds)) - for _, userID := range req.TrialIds { - trialIds = append(trialIds, strconv.Itoa(int(userID))) + for _, trialID := range req.TrialIds { + trialIds = append(trialIds, strconv.Itoa(int(trialID))) } trialIDFilterExpr := strings.Join(trialIds, ",") - switch err := a.m.db.QueryProto("get_trials_basic", &trialsObjList, trialIDFilterExpr); { - case err == db.ErrNotFound: - return nil, status.Errorf(codes.NotFound, "trial not found") - case err != nil: + err = a.m.db.QueryProto("get_trials_basic", &trialsObjList, trialIDFilterExpr) + if err != nil { return nil, errors.Wrapf(err, "failed to get trials") } for _, trialObj := range trialsObjList { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialObj.Id), - experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - return nil, err - } - container := &apiv1.ComparableTrial{Trial: trialObj} tsample, err := a.multiTrialSample(trialObj.Id, req.MetricNames, metricGroup, From a2df316c8c7791227abd2e85c4b7f1400e8891ba Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Wed, 14 Aug 2024 12:21:58 -0400 Subject: [PATCH 03/28] add perf test --- performance/k6/src/api_performance_tests.ts | 80 +++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/performance/k6/src/api_performance_tests.ts b/performance/k6/src/api_performance_tests.ts index 2ed7ab40141..ec4577824dc 100644 --- a/performance/k6/src/api_performance_tests.ts +++ b/performance/k6/src/api_performance_tests.ts @@ -328,6 +328,86 @@ const getloadTests = ( test("get group", getRequest(`/api/v1/groups/{groupId}`), RBAC_ENABLED), test("search groups", getRequest("/api/v1/groups/search"), RBAC_ENABLED), test("search roles", getRequest(`/api/v1/roles-search`), RBAC_ENABLED), + test("compare runs", getRequest(`/api/v1/trials/time-series?&trialIds=216 + &trialIds=218 + &trialIds=219 + &trialIds=220 + &trialIds=221 + &trialIds=223 + &trialIds=225 + &trialIds=226 + &trialIds=227 + &trialIds=228 + &trialIds=229 + &trialIds=230 + &trialIds=231 + &trialIds=232 + &trialIds=233 + &trialIds=234 + &trialIds=235 + &trialIds=236 + &trialIds=239 + &trialIds=240 + &trialIds=242 + &trialIds=246 + &trialIds=247 + &trialIds=249 + &trialIds=275 + &trialIds=276 + &trialIds=277 + &trialIds=278 + &trialIds=279 + &trialIds=280 + &trialIds=281 + &trialIds=282 + &trialIds=283 + &trialIds=284 + &trialIds=285 + &trialIds=286 + &trialIds=287 + &trialIds=288 + &trialIds=289 + &trialIds=290 + &trialIds=291 + &trialIds=292 + &trialIds=293 + &trialIds=294 + &trialIds=295 + &trialIds=296 + &trialIds=297 + &trialIds=298 + &trialIds=299 + &trialIds=300 + &trialIds=301 + &trialIds=302 + &trialIds=303 + &trialIds=304 + &trialIds=305 + &trialIds=306 + &trialIds=307 + &trialIds=308 + &trialIds=309 + &trialIds=310 + &trialIds=311 + &trialIds=312 + &trialIds=313 + &trialIds=314 + &trialIds=315 + &trialIds=316 + &trialIds=317 + &trialIds=318 + &trialIds=319 + &trialIds=320 + &trialIds=321 + &trialIds=322 + &trialIds=323 + &trialIds=324 + &trialIds=325 + &trialIds=326 + &trialIds=327 + &trialIds=328 + &trialIds=329 + &trialIds=330&maxDatapoints=1500&startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy&metricIds=training.loss&metricIds=validation.val_categorical_accuracy&metricIds=validation.val_loss`), RBAC_ENABLED), test( "search roles by group", getRequest(`/api/v1/roles/search/by-group/{groupId}`), From ffcabd62813dd979c553d0103d1d0b9c11b05418 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Wed, 14 Aug 2024 13:32:11 -0400 Subject: [PATCH 04/28] move test --- performance/daist/daist/rest_api/tasks.py | 80 +++++++++++++++++++++ performance/k6/src/api_performance_tests.ts | 80 --------------------- 2 files changed, 80 insertions(+), 80 deletions(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 1ee46638591..18dcd470371 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -276,6 +276,86 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: encoded_filter = parse.quote_plus(serialized_filter) tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs?filter={encoded_filter}", test_name="search flat runs w/ hparam")) + tasks.append(LocustGetTaskWithMeta('''/api/v1/trials/time-series?&trialIds=216 + &trialIds=218 + &trialIds=219 + &trialIds=220 + &trialIds=221 + &trialIds=223 + &trialIds=225 + &trialIds=226 + &trialIds=227 + &trialIds=228 + &trialIds=229 + &trialIds=230 + &trialIds=231 + &trialIds=232 + &trialIds=233 + &trialIds=234 + &trialIds=235 + &trialIds=236 + &trialIds=239 + &trialIds=240 + &trialIds=242 + &trialIds=246 + &trialIds=247 + &trialIds=249 + &trialIds=275 + &trialIds=276 + &trialIds=277 + &trialIds=278 + &trialIds=279 + &trialIds=280 + &trialIds=281 + &trialIds=282 + &trialIds=283 + &trialIds=284 + &trialIds=285 + &trialIds=286 + &trialIds=287 + &trialIds=288 + &trialIds=289 + &trialIds=290 + &trialIds=291 + &trialIds=292 + &trialIds=293 + &trialIds=294 + &trialIds=295 + &trialIds=296 + &trialIds=297 + &trialIds=298 + &trialIds=299 + &trialIds=300 + &trialIds=301 + &trialIds=302 + &trialIds=303 + &trialIds=304 + &trialIds=305 + &trialIds=306 + &trialIds=307 + &trialIds=308 + &trialIds=309 + &trialIds=310 + &trialIds=311 + &trialIds=312 + &trialIds=313 + &trialIds=314 + &trialIds=315 + &trialIds=316 + &trialIds=317 + &trialIds=318 + &trialIds=319 + &trialIds=320 + &trialIds=321 + &trialIds=322 + &trialIds=323 + &trialIds=324 + &trialIds=325 + &trialIds=326 + &trialIds=327 + &trialIds=328 + &trialIds=329 + &trialIds=330&maxDatapoints=1500&startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy&metricIds=training.loss&metricIds=validation.val_categorical_accuracy&metricIds=validation.val_loss''', test_name="compare runs")) if resources.project_id is not None: tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs", test_name="search flat runs", body={"projectId": resources.project_id})) diff --git a/performance/k6/src/api_performance_tests.ts b/performance/k6/src/api_performance_tests.ts index ec4577824dc..2ed7ab40141 100644 --- a/performance/k6/src/api_performance_tests.ts +++ b/performance/k6/src/api_performance_tests.ts @@ -328,86 +328,6 @@ const getloadTests = ( test("get group", getRequest(`/api/v1/groups/{groupId}`), RBAC_ENABLED), test("search groups", getRequest("/api/v1/groups/search"), RBAC_ENABLED), test("search roles", getRequest(`/api/v1/roles-search`), RBAC_ENABLED), - test("compare runs", getRequest(`/api/v1/trials/time-series?&trialIds=216 - &trialIds=218 - &trialIds=219 - &trialIds=220 - &trialIds=221 - &trialIds=223 - &trialIds=225 - &trialIds=226 - &trialIds=227 - &trialIds=228 - &trialIds=229 - &trialIds=230 - &trialIds=231 - &trialIds=232 - &trialIds=233 - &trialIds=234 - &trialIds=235 - &trialIds=236 - &trialIds=239 - &trialIds=240 - &trialIds=242 - &trialIds=246 - &trialIds=247 - &trialIds=249 - &trialIds=275 - &trialIds=276 - &trialIds=277 - &trialIds=278 - &trialIds=279 - &trialIds=280 - &trialIds=281 - &trialIds=282 - &trialIds=283 - &trialIds=284 - &trialIds=285 - &trialIds=286 - &trialIds=287 - &trialIds=288 - &trialIds=289 - &trialIds=290 - &trialIds=291 - &trialIds=292 - &trialIds=293 - &trialIds=294 - &trialIds=295 - &trialIds=296 - &trialIds=297 - &trialIds=298 - &trialIds=299 - &trialIds=300 - &trialIds=301 - &trialIds=302 - &trialIds=303 - &trialIds=304 - &trialIds=305 - &trialIds=306 - &trialIds=307 - &trialIds=308 - &trialIds=309 - &trialIds=310 - &trialIds=311 - &trialIds=312 - &trialIds=313 - &trialIds=314 - &trialIds=315 - &trialIds=316 - &trialIds=317 - &trialIds=318 - &trialIds=319 - &trialIds=320 - &trialIds=321 - &trialIds=322 - &trialIds=323 - &trialIds=324 - &trialIds=325 - &trialIds=326 - &trialIds=327 - &trialIds=328 - &trialIds=329 - &trialIds=330&maxDatapoints=1500&startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy&metricIds=training.loss&metricIds=validation.val_categorical_accuracy&metricIds=validation.val_loss`), RBAC_ENABLED), test( "search roles by group", getRequest(`/api/v1/roles/search/by-group/{groupId}`), From 8c0600b35302c5ec0936f9b696b803edb42185c2 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Wed, 14 Aug 2024 13:59:16 -0400 Subject: [PATCH 05/28] formatting --- performance/daist/daist/rest_api/tasks.py | 98 +++++------------------ 1 file changed, 18 insertions(+), 80 deletions(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 18dcd470371..7acf13de7a6 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -276,86 +276,24 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: encoded_filter = parse.quote_plus(serialized_filter) tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs?filter={encoded_filter}", test_name="search flat runs w/ hparam")) - tasks.append(LocustGetTaskWithMeta('''/api/v1/trials/time-series?&trialIds=216 - &trialIds=218 - &trialIds=219 - &trialIds=220 - &trialIds=221 - &trialIds=223 - &trialIds=225 - &trialIds=226 - &trialIds=227 - &trialIds=228 - &trialIds=229 - &trialIds=230 - &trialIds=231 - &trialIds=232 - &trialIds=233 - &trialIds=234 - &trialIds=235 - &trialIds=236 - &trialIds=239 - &trialIds=240 - &trialIds=242 - &trialIds=246 - &trialIds=247 - &trialIds=249 - &trialIds=275 - &trialIds=276 - &trialIds=277 - &trialIds=278 - &trialIds=279 - &trialIds=280 - &trialIds=281 - &trialIds=282 - &trialIds=283 - &trialIds=284 - &trialIds=285 - &trialIds=286 - &trialIds=287 - &trialIds=288 - &trialIds=289 - &trialIds=290 - &trialIds=291 - &trialIds=292 - &trialIds=293 - &trialIds=294 - &trialIds=295 - &trialIds=296 - &trialIds=297 - &trialIds=298 - &trialIds=299 - &trialIds=300 - &trialIds=301 - &trialIds=302 - &trialIds=303 - &trialIds=304 - &trialIds=305 - &trialIds=306 - &trialIds=307 - &trialIds=308 - &trialIds=309 - &trialIds=310 - &trialIds=311 - &trialIds=312 - &trialIds=313 - &trialIds=314 - &trialIds=315 - &trialIds=316 - &trialIds=317 - &trialIds=318 - &trialIds=319 - &trialIds=320 - &trialIds=321 - &trialIds=322 - &trialIds=323 - &trialIds=324 - &trialIds=325 - &trialIds=326 - &trialIds=327 - &trialIds=328 - &trialIds=329 - &trialIds=330&maxDatapoints=1500&startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy&metricIds=training.loss&metricIds=validation.val_categorical_accuracy&metricIds=validation.val_loss''', test_name="compare runs")) + tasks.append(LocustGetTaskWithMeta( + '''/api/v1/trials/time-series?trialIds=216&trialIds=218&trialIds=219&trialIds=220 + &trialIds=221&trialIds=223&trialIds=225&trialIds=226&trialIds=227&trialIds=228 + &trialIds=229&trialIds=230&trialIds=231&trialIds=232&trialIds=233&trialIds=234 + &trialIds=235&trialIds=236&trialIds=239&trialIds=240&trialIds=242&trialIds=246 + &trialIds=247&trialIds=249&trialIds=275&trialIds=276&trialIds=277&trialIds=278 + &trialIds=279&trialIds=280&trialIds=281&trialIds=282&trialIds=283&trialIds=284 + &trialIds=285&trialIds=286&trialIds=287&trialIds=288&trialIds=289&trialIds=290 + &trialIds=291&trialIds=292&trialIds=293&trialIds=294&trialIds=295&trialIds=296 + &trialIds=297&trialIds=298&trialIds=299&trialIds=300&trialIds=301&trialIds=302 + &trialIds=303&trialIds=304&trialIds=305&trialIds=306&trialIds=307&trialIds=308 + &trialIds=309&trialIds=310&trialIds=311&trialIds=312&trialIds=313&trialIds=314 + &trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320 + &trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326 + &trialIds=327&trialIds=328&trialIds=329&trialIds=330&maxDatapoints=1500 + &startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy + &metricIds=training.loss&metricIds=validation.val_categorical_accuracy + &metricIds=validation.val_loss''', test_name="compare runs")) if resources.project_id is not None: tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs", test_name="search flat runs", body={"projectId": resources.project_id})) From e30790d9858e1444ac30862928e5a2e6e4d3da67 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Wed, 14 Aug 2024 14:25:33 -0400 Subject: [PATCH 06/28] formatting --- performance/daist/daist/rest_api/tasks.py | 34 +++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 7acf13de7a6..a986ae0de2b 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -277,23 +277,23 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs?filter={encoded_filter}", test_name="search flat runs w/ hparam")) tasks.append(LocustGetTaskWithMeta( - '''/api/v1/trials/time-series?trialIds=216&trialIds=218&trialIds=219&trialIds=220 - &trialIds=221&trialIds=223&trialIds=225&trialIds=226&trialIds=227&trialIds=228 - &trialIds=229&trialIds=230&trialIds=231&trialIds=232&trialIds=233&trialIds=234 - &trialIds=235&trialIds=236&trialIds=239&trialIds=240&trialIds=242&trialIds=246 - &trialIds=247&trialIds=249&trialIds=275&trialIds=276&trialIds=277&trialIds=278 - &trialIds=279&trialIds=280&trialIds=281&trialIds=282&trialIds=283&trialIds=284 - &trialIds=285&trialIds=286&trialIds=287&trialIds=288&trialIds=289&trialIds=290 - &trialIds=291&trialIds=292&trialIds=293&trialIds=294&trialIds=295&trialIds=296 - &trialIds=297&trialIds=298&trialIds=299&trialIds=300&trialIds=301&trialIds=302 - &trialIds=303&trialIds=304&trialIds=305&trialIds=306&trialIds=307&trialIds=308 - &trialIds=309&trialIds=310&trialIds=311&trialIds=312&trialIds=313&trialIds=314 - &trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320 - &trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326 - &trialIds=327&trialIds=328&trialIds=329&trialIds=330&maxDatapoints=1500 - &startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy - &metricIds=training.loss&metricIds=validation.val_categorical_accuracy - &metricIds=validation.val_loss''', test_name="compare runs")) + "/api/v1/trials/time-series?trialIds=216&trialIds=218&trialIds=219&trialIds=220" + + "&trialIds=221&trialIds=223&trialIds=225&trialIds=226&trialIds=227&trialIds=228" + + "&trialIds=229&trialIds=230&trialIds=231&trialIds=232&trialIds=233&trialIds=234" + + "&trialIds=235&trialIds=236&trialIds=239&trialIds=240&trialIds=242&trialIds=246" + + "&trialIds=247&trialIds=249&trialIds=275&trialIds=276&trialIds=277&trialIds=278" + + "&trialIds=279&trialIds=280&trialIds=281&trialIds=282&trialIds=283&trialIds=284" + + "&trialIds=285&trialIds=286&trialIds=287&trialIds=288&trialIds=289&trialIds=290" + + "&trialIds=291&trialIds=292&trialIds=293&trialIds=294&trialIds=295&trialIds=296" + + "&trialIds=297&trialIds=298&trialIds=299&trialIds=300&trialIds=301&trialIds=302" + + "&trialIds=303&trialIds=304&trialIds=305&trialIds=306&trialIds=307&trialIds=308" + + "&trialIds=309&trialIds=310&trialIds=311&trialIds=312&trialIds=313&trialIds=314" + + "&trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320" + + "&trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326" + + "&trialIds=327&trialIds=328&trialIds=329&trialIds=330&maxDatapoints=1500" + + "&startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy" + + "&metricIds=training.loss&metricIds=validation.val_categorical_accuracy" + + "&metricIds=validation.val_loss", test_name="compare runs")) if resources.project_id is not None: tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs", test_name="search flat runs", body={"projectId": resources.project_id})) From ba0cb10f279dcf86482f56f49eececd777c76014 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Wed, 14 Aug 2024 15:19:12 -0400 Subject: [PATCH 07/28] update call --- performance/daist/daist/rest_api/tasks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index a986ae0de2b..93cf3362eb9 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -291,9 +291,8 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: "&trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320" + "&trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326" + "&trialIds=327&trialIds=328&trialIds=329&trialIds=330&maxDatapoints=1500" + - "&startBatches=0&metricType=METRIC_TYPE_UNSPECIFIED&metricIds=training.categorical_accuracy" + - "&metricIds=training.loss&metricIds=validation.val_categorical_accuracy" + - "&metricIds=validation.val_loss", test_name="compare runs")) + "&startBatches=0&metricType=METRIC_TYPE_TRAINING&metricIds=avg_metrics.c747a" + + "&metricIds=avg_metrics.f0", test_name="compare runs")) if resources.project_id is not None: tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs", test_name="search flat runs", body={"projectId": resources.project_id})) From 8eae1e0bbd34d4829c12fdbdcd9705d94ee10248 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Wed, 14 Aug 2024 15:41:55 -0400 Subject: [PATCH 08/28] redo perf test --- performance/daist/daist/rest_api/tasks.py | 27 +++++++++-------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 93cf3362eb9..710a57aa2ff 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -150,6 +150,16 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: "trialIds": resources.trial_id, "startBatches": 0, "metricType": "METRIC_TYPE_UNSPECIFIED"})) + + tasks.append(LocustGetTaskWithMeta( + "/api/v1/trials/time-series", test_name="get trials time series large payload", + params={ + "trialIds": [216,218,219,220,221,223,225,226,227,228,229,230,231,232,233,234,235,236,239,240, + 242,246,247,249,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291, + 292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310, + 311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330], + "startBatches": 0, + "metricType": "METRIC_TYPE_UNSPECIFIED"})) tasks.append(LocustGetTaskWithMeta(f"/api/v1/trials/{resources.trial_id}/checkpoints", test_name="get trial checkpoints")) @@ -276,23 +286,6 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: encoded_filter = parse.quote_plus(serialized_filter) tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs?filter={encoded_filter}", test_name="search flat runs w/ hparam")) - tasks.append(LocustGetTaskWithMeta( - "/api/v1/trials/time-series?trialIds=216&trialIds=218&trialIds=219&trialIds=220" + - "&trialIds=221&trialIds=223&trialIds=225&trialIds=226&trialIds=227&trialIds=228" + - "&trialIds=229&trialIds=230&trialIds=231&trialIds=232&trialIds=233&trialIds=234" + - "&trialIds=235&trialIds=236&trialIds=239&trialIds=240&trialIds=242&trialIds=246" + - "&trialIds=247&trialIds=249&trialIds=275&trialIds=276&trialIds=277&trialIds=278" + - "&trialIds=279&trialIds=280&trialIds=281&trialIds=282&trialIds=283&trialIds=284" + - "&trialIds=285&trialIds=286&trialIds=287&trialIds=288&trialIds=289&trialIds=290" + - "&trialIds=291&trialIds=292&trialIds=293&trialIds=294&trialIds=295&trialIds=296" + - "&trialIds=297&trialIds=298&trialIds=299&trialIds=300&trialIds=301&trialIds=302" + - "&trialIds=303&trialIds=304&trialIds=305&trialIds=306&trialIds=307&trialIds=308" + - "&trialIds=309&trialIds=310&trialIds=311&trialIds=312&trialIds=313&trialIds=314" + - "&trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320" + - "&trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326" + - "&trialIds=327&trialIds=328&trialIds=329&trialIds=330&maxDatapoints=1500" + - "&startBatches=0&metricType=METRIC_TYPE_TRAINING&metricIds=avg_metrics.c747a" + - "&metricIds=avg_metrics.f0", test_name="compare runs")) if resources.project_id is not None: tasks.append(LocustPostTaskWithMeta(f"/api/v1/runs", test_name="search flat runs", body={"projectId": resources.project_id})) From 98a4cd3b6ffffef55310f174c4cf439dd106fa49 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 10:17:12 -0400 Subject: [PATCH 09/28] format --- performance/daist/daist/rest_api/tasks.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 710a57aa2ff..4d5865c4e90 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -154,10 +154,20 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: tasks.append(LocustGetTaskWithMeta( "/api/v1/trials/time-series", test_name="get trials time series large payload", params={ - "trialIds": [216,218,219,220,221,223,225,226,227,228,229,230,231,232,233,234,235,236,239,240, - 242,246,247,249,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291, - 292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310, - 311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330], + "trialIds": "trialIds=216&trialIds=218&trialIds=219&trialIds=220" + + "&trialIds=221&trialIds=223&trialIds=225&trialIds=226&trialIds=227&trialIds=228" + + "&trialIds=229&trialIds=230&trialIds=231&trialIds=232&trialIds=233&trialIds=234" + + "&trialIds=235&trialIds=236&trialIds=239&trialIds=240&trialIds=242&trialIds=246" + + "&trialIds=247&trialIds=249&trialIds=275&trialIds=276&trialIds=277&trialIds=278" + + "&trialIds=279&trialIds=280&trialIds=281&trialIds=282&trialIds=283&trialIds=284" + + "&trialIds=285&trialIds=286&trialIds=287&trialIds=288&trialIds=289&trialIds=290" + + "&trialIds=291&trialIds=292&trialIds=293&trialIds=294&trialIds=295&trialIds=296" + + "&trialIds=297&trialIds=298&trialIds=299&trialIds=300&trialIds=301&trialIds=302" + + "&trialIds=303&trialIds=304&trialIds=305&trialIds=306&trialIds=307&trialIds=308" + + "&trialIds=309&trialIds=310&trialIds=311&trialIds=312&trialIds=313&trialIds=314" + + "&trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320" + + "&trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326" + + "&trialIds=327&trialIds=328&trialIds=329&trialIds=330", "startBatches": 0, "metricType": "METRIC_TYPE_UNSPECIFIED"})) From 5f11bbf5333cdff22646f21c69924b46ff820e1b Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 10:25:32 -0400 Subject: [PATCH 10/28] do sequence --- .../daist/daist/rest_api/locust_utils.py | 2 +- performance/daist/daist/rest_api/tasks.py | 18 ++++-------------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/performance/daist/daist/rest_api/locust_utils.py b/performance/daist/daist/rest_api/locust_utils.py index 6b418b7ec9e..065bf530531 100644 --- a/performance/daist/daist/rest_api/locust_utils.py +++ b/performance/daist/daist/rest_api/locust_utils.py @@ -100,7 +100,7 @@ def task(self, user: HttpUser) -> Response: def get_task(endpoint, params=None) -> HttpTask_t: query_string = "" if params is not None: - query_string = f"?{urlencode(params)}" + query_string = f"?{urlencode(params, True)}" def task(user: HttpUser): return user.client.get(f"{endpoint}{query_string}") diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 4d5865c4e90..710a57aa2ff 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -154,20 +154,10 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: tasks.append(LocustGetTaskWithMeta( "/api/v1/trials/time-series", test_name="get trials time series large payload", params={ - "trialIds": "trialIds=216&trialIds=218&trialIds=219&trialIds=220" + - "&trialIds=221&trialIds=223&trialIds=225&trialIds=226&trialIds=227&trialIds=228" + - "&trialIds=229&trialIds=230&trialIds=231&trialIds=232&trialIds=233&trialIds=234" + - "&trialIds=235&trialIds=236&trialIds=239&trialIds=240&trialIds=242&trialIds=246" + - "&trialIds=247&trialIds=249&trialIds=275&trialIds=276&trialIds=277&trialIds=278" + - "&trialIds=279&trialIds=280&trialIds=281&trialIds=282&trialIds=283&trialIds=284" + - "&trialIds=285&trialIds=286&trialIds=287&trialIds=288&trialIds=289&trialIds=290" + - "&trialIds=291&trialIds=292&trialIds=293&trialIds=294&trialIds=295&trialIds=296" + - "&trialIds=297&trialIds=298&trialIds=299&trialIds=300&trialIds=301&trialIds=302" + - "&trialIds=303&trialIds=304&trialIds=305&trialIds=306&trialIds=307&trialIds=308" + - "&trialIds=309&trialIds=310&trialIds=311&trialIds=312&trialIds=313&trialIds=314" + - "&trialIds=315&trialIds=316&trialIds=317&trialIds=318&trialIds=319&trialIds=320" + - "&trialIds=321&trialIds=322&trialIds=323&trialIds=324&trialIds=325&trialIds=326" + - "&trialIds=327&trialIds=328&trialIds=329&trialIds=330", + "trialIds": [216,218,219,220,221,223,225,226,227,228,229,230,231,232,233,234,235,236,239,240, + 242,246,247,249,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291, + 292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310, + 311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330], "startBatches": 0, "metricType": "METRIC_TYPE_UNSPECIFIED"})) From 4accdcaf84a9ee70de071f0eabe84139f9887489 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 10:46:43 -0400 Subject: [PATCH 11/28] try again --- performance/daist/daist/rest_api/locust_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/performance/daist/daist/rest_api/locust_utils.py b/performance/daist/daist/rest_api/locust_utils.py index 065bf530531..1a6f1fe5e22 100644 --- a/performance/daist/daist/rest_api/locust_utils.py +++ b/performance/daist/daist/rest_api/locust_utils.py @@ -100,7 +100,7 @@ def task(self, user: HttpUser) -> Response: def get_task(endpoint, params=None) -> HttpTask_t: query_string = "" if params is not None: - query_string = f"?{urlencode(params, True)}" + query_string = f"?{urlencode(params, doseq=True)}" def task(user: HttpUser): return user.client.get(f"{endpoint}{query_string}") From 6fafd70ce54f7b7ff9d81bf5a309facc14fc26c1 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 11:13:15 -0400 Subject: [PATCH 12/28] wrong spot --- performance/daist/daist/rest_api/locust_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/performance/daist/daist/rest_api/locust_utils.py b/performance/daist/daist/rest_api/locust_utils.py index 1a6f1fe5e22..84f8786aa67 100644 --- a/performance/daist/daist/rest_api/locust_utils.py +++ b/performance/daist/daist/rest_api/locust_utils.py @@ -73,7 +73,7 @@ def __init__(self, endpoint: Url_t, *, params: Optional[Dict] = None, self.params = params self._url = endpoint if params is not None: - self._url += f'?{urlencode(params)}' + self._url += f'?{urlencode(params, doseq=True)}' @property def url(self) -> Url_t: @@ -100,7 +100,7 @@ def task(self, user: HttpUser) -> Response: def get_task(endpoint, params=None) -> HttpTask_t: query_string = "" if params is not None: - query_string = f"?{urlencode(params, doseq=True)}" + query_string = f"?{urlencode(params)}" def task(user: HttpUser): return user.client.get(f"{endpoint}{query_string}") From 039172db99ffe8cb9ef62d0f87dddb3a124fcf82 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 11:37:34 -0400 Subject: [PATCH 13/28] wrap errors --- master/internal/api_trials.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index 0293bb2b7ec..521611eba4a 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -905,13 +905,13 @@ func (a *apiServer) CompareTrials(ctx context.Context, for _, trialID := range req.TrialIds { if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - return nil, err + return nil, errors.Wrapf(err, "failed validate permissions") } } //nolint:staticcheck // SA1019: backward compatibility metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "failed to parse metric group args") } trialsObjList := []*trialv1.Trial{} From d46ea7f298da88b643b5e90877b12409ab5ac550 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 11:57:57 -0400 Subject: [PATCH 14/28] try with perm check --- master/internal/api_trials.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index 521611eba4a..0aaeaa1d55e 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -902,12 +902,12 @@ func (a *apiServer) CompareTrials(ctx context.Context, req *apiv1.CompareTrialsRequest, ) (*apiv1.CompareTrialsResponse, error) { trialsList := make([]*apiv1.ComparableTrial, 0, len(req.TrialIds)) - for _, trialID := range req.TrialIds { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), - experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - return nil, errors.Wrapf(err, "failed validate permissions") - } - } + // for _, trialID := range req.TrialIds { + // if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), + // experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { + // return nil, errors.Wrapf(err, "failed validate permissions") + // } + // } //nolint:staticcheck // SA1019: backward compatibility metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) if err != nil { From 50881777b4da735cf1ad4d703c6991ba1b5b5acf Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 13:12:52 -0400 Subject: [PATCH 15/28] newer ids --- master/internal/api_trials.go | 12 ++++++------ performance/daist/daist/rest_api/tasks.py | 18 ++++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index 0aaeaa1d55e..521611eba4a 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -902,12 +902,12 @@ func (a *apiServer) CompareTrials(ctx context.Context, req *apiv1.CompareTrialsRequest, ) (*apiv1.CompareTrialsResponse, error) { trialsList := make([]*apiv1.ComparableTrial, 0, len(req.TrialIds)) - // for _, trialID := range req.TrialIds { - // if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), - // experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - // return nil, errors.Wrapf(err, "failed validate permissions") - // } - // } + for _, trialID := range req.TrialIds { + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), + experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { + return nil, errors.Wrapf(err, "failed validate permissions") + } + } //nolint:staticcheck // SA1019: backward compatibility metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) if err != nil { diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 710a57aa2ff..394e79c57f3 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -153,14 +153,16 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: tasks.append(LocustGetTaskWithMeta( "/api/v1/trials/time-series", test_name="get trials time series large payload", - params={ - "trialIds": [216,218,219,220,221,223,225,226,227,228,229,230,231,232,233,234,235,236,239,240, - 242,246,247,249,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291, - 292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310, - 311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330], - "startBatches": 0, - "metricType": "METRIC_TYPE_UNSPECIFIED"})) - + params={ + "trialIds": [8283,8284,8285,8286,8287,8288,8289,8290,8291,8292,8293,8294,8295, + 8296,8297,8298,8299,8300,8301,8302,8303,8304,8305,8306,8307,8308, + 8309,8310,8311,8312,8313,8314,8315,8316,8317,8318,8319,8320,8321, + 8322,8323,8324,8325,8326,8327,8328,8329,8330,8331,8332,8333,8334, + 8335,8336,8337,8338,8339,8340,8341,8342,8343,8344,8345,8346,8347, + 8348,8349,8350,8351,8352,8353,8354,8355,8356,8357,8358,8359,8360, + 8361,8362], + "startBatches": 0, + "metricType": "METRIC_TYPE_UNSPECIFIED"} tasks.append(LocustGetTaskWithMeta(f"/api/v1/trials/{resources.trial_id}/checkpoints", test_name="get trial checkpoints")) tasks.append(LocustGetTaskWithMeta(f"/api/v1/trials/{resources.trial_id}/profiler/metrics", From 94286c6c477f99976f4f6320d8e52cdf1ec29fb1 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 13:31:25 -0400 Subject: [PATCH 16/28] fix syntax error --- performance/daist/daist/rest_api/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/performance/daist/daist/rest_api/tasks.py b/performance/daist/daist/rest_api/tasks.py index 394e79c57f3..ac682b8be0a 100644 --- a/performance/daist/daist/rest_api/tasks.py +++ b/performance/daist/daist/rest_api/tasks.py @@ -162,7 +162,7 @@ def read_only_tasks(resources: Resources) -> LocustTasksWithMeta: 8348,8349,8350,8351,8352,8353,8354,8355,8356,8357,8358,8359,8360, 8361,8362], "startBatches": 0, - "metricType": "METRIC_TYPE_UNSPECIFIED"} + "metricType": "METRIC_TYPE_UNSPECIFIED"})) tasks.append(LocustGetTaskWithMeta(f"/api/v1/trials/{resources.trial_id}/checkpoints", test_name="get trial checkpoints")) tasks.append(LocustGetTaskWithMeta(f"/api/v1/trials/{resources.trial_id}/profiler/metrics", From 6429347b5e06b096a3d54e9bb9745eb9bea8cc5a Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 15:06:16 -0400 Subject: [PATCH 17/28] update perm check --- master/internal/api_trials.go | 16 ++++++------- master/internal/api_trials_intg_test.go | 12 +++++----- master/internal/db/postgres_experiments.go | 16 +++++++++++++ master/internal/trials/utils.go | 27 ++++++++++++++++++++++ 4 files changed, 57 insertions(+), 14 deletions(-) diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index 521611eba4a..e2ab9ebd787 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -902,11 +902,15 @@ func (a *apiServer) CompareTrials(ctx context.Context, req *apiv1.CompareTrialsRequest, ) (*apiv1.CompareTrialsResponse, error) { trialsList := make([]*apiv1.ComparableTrial, 0, len(req.TrialIds)) + trialIds := make([]string, 0, len(req.TrialIds)) + trialIntList := make([]int, 0, len(req.TrialIds)) for _, trialID := range req.TrialIds { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), - experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { - return nil, errors.Wrapf(err, "failed validate permissions") - } + trialIds = append(trialIds, strconv.Itoa(int(trialID))) + trialIntList = append(trialIntList, int(trialID)) + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, trialIntList, + experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { + return nil, errors.Wrapf(err, "failed validate permissions") } //nolint:staticcheck // SA1019: backward compatibility metricGroup, err := a.parseMetricGroupArgs(req.MetricType, model.MetricGroup(req.Group)) @@ -916,10 +920,6 @@ func (a *apiServer) CompareTrials(ctx context.Context, trialsObjList := []*trialv1.Trial{} - trialIds := make([]string, 0, len(req.TrialIds)) - for _, trialID := range req.TrialIds { - trialIds = append(trialIds, strconv.Itoa(int(trialID))) - } trialIDFilterExpr := strings.Join(trialIds, ",") err = a.m.db.QueryProto("get_trials_basic", &trialsObjList, trialIDFilterExpr) diff --git a/master/internal/api_trials_intg_test.go b/master/internal/api_trials_intg_test.go index 7bd9e2ad8aa..e500d319efc 100644 --- a/master/internal/api_trials_intg_test.go +++ b/master/internal/api_trials_intg_test.go @@ -712,12 +712,12 @@ func TestTrialAuthZ(t *testing.T) { }) return err }, true}, - {"CanGetExperimentArtifacts", func(id int) error { - _, err := api.CompareTrials(ctx, &apiv1.CompareTrialsRequest{ - TrialIds: []int32{int32(id)}, - }) - return err - }, false}, + // {"CanGetExperimentArtifacts", func(id int) error { + // _, err := api.CompareTrials(ctx, &apiv1.CompareTrialsRequest{ + // TrialIds: []int32{int32(id)}, + // }) + // return err + // }, false}, {"CanGetExperimentArtifacts", func(id int) error { _, err := api.GetTrialWorkloads(ctx, &apiv1.GetTrialWorkloadsRequest{ TrialId: int32(id), diff --git a/master/internal/db/postgres_experiments.go b/master/internal/db/postgres_experiments.go index 02f81bf8527..37b4cf38412 100644 --- a/master/internal/db/postgres_experiments.go +++ b/master/internal/db/postgres_experiments.go @@ -521,6 +521,22 @@ WHERE t.id = ?`, trialID).Scan(ctx, &experiment); err != nil { return &experiment, nil } +func ExperimentsByTrialID(ctx context.Context, trialIDs []int) ([]*model.Experiment, error) { + var experiment []*model.Experiment + + if err := Bun().NewRaw(` +SELECT e.id, e.state, e.config, e.start_time, e.end_time, e.archived, + e.owner_id, e.notes, e.job_id, u.username as username, e.project_id, unmanaged, external_experiment_id +FROM experiments e +JOIN trials t ON e.id = t.experiment_id +JOIN users u ON (e.owner_id = u.id) +WHERE t.id IN (?)`, bun.In(trialIDs)).Scan(ctx, &experiment); err != nil { + return nil, MatchSentinelError(err) + } + + return experiment, nil +} + // ExperimentByTaskID looks up an experiment by a given taskID, returning an error // if none exists. func ExperimentByTaskID( diff --git a/master/internal/trials/utils.go b/master/internal/trials/utils.go index c05c8535b6f..11e5fdf8552 100644 --- a/master/internal/trials/utils.go +++ b/master/internal/trials/utils.go @@ -42,3 +42,30 @@ func CanGetTrialsExperimentAndCheckCanDoAction(ctx context.Context, } return nil } + +func CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx context.Context, + trialIDs []int, actionFunc func(context.Context, model.User, *model.Experiment) error, +) error { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return err + } + + trialNotFound := api.NotFoundErrs("trial", "(multi-trial search)", true) + exps, err := db.ExperimentsByTrialID(ctx, trialIDs) + if errors.Is(err, db.ErrNotFound) { + return trialNotFound + } else if err != nil { + return err + } + for _, exp := range exps { + if err = experiment.AuthZProvider.Get().CanGetExperiment(ctx, *curUser, exp); err != nil { + return authz.SubIfUnauthorized(err, trialNotFound) + } + + if err = actionFunc(ctx, *curUser, exp); err != nil { + return status.Error(codes.PermissionDenied, err.Error()) + } + } + return nil +} From eb607a5b61cd5c9318ef4c4a246f5e8d28f8b1bf Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 15:30:33 -0400 Subject: [PATCH 18/28] lint --- master/internal/trials/utils.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/master/internal/trials/utils.go b/master/internal/trials/utils.go index 11e5fdf8552..f490f2bb3be 100644 --- a/master/internal/trials/utils.go +++ b/master/internal/trials/utils.go @@ -43,6 +43,8 @@ func CanGetTrialsExperimentAndCheckCanDoAction(ctx context.Context, return nil } +// CanGetTrialsExperimentAndCheckCanDoActionBulk functions the same as +// CanGetTrialsExperimentAndCheckCanDoAction but takes in multiple trial ids func CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx context.Context, trialIDs []int, actionFunc func(context.Context, model.User, *model.Experiment) error, ) error { From 14235c339e9bb458256b15a08f9f9cc4a63e8b94 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 15:31:19 -0400 Subject: [PATCH 19/28] lint --- master/internal/trials/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/master/internal/trials/utils.go b/master/internal/trials/utils.go index f490f2bb3be..aa86fd95466 100644 --- a/master/internal/trials/utils.go +++ b/master/internal/trials/utils.go @@ -44,7 +44,7 @@ func CanGetTrialsExperimentAndCheckCanDoAction(ctx context.Context, } // CanGetTrialsExperimentAndCheckCanDoActionBulk functions the same as -// CanGetTrialsExperimentAndCheckCanDoAction but takes in multiple trial ids +// CanGetTrialsExperimentAndCheckCanDoAction but takes in multiple trial ids. func CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx context.Context, trialIDs []int, actionFunc func(context.Context, model.User, *model.Experiment) error, ) error { From 0735ee29590f7429876702d16b68e320499889ce Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 15:38:09 -0400 Subject: [PATCH 20/28] lint --- master/internal/db/postgres_experiments.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/master/internal/db/postgres_experiments.go b/master/internal/db/postgres_experiments.go index 37b4cf38412..2ef04cfe6d1 100644 --- a/master/internal/db/postgres_experiments.go +++ b/master/internal/db/postgres_experiments.go @@ -521,6 +521,8 @@ WHERE t.id = ?`, trialID).Scan(ctx, &experiment); err != nil { return &experiment, nil } +// ExperimenstByTrialID looks up an experiment by a given list of trialIDs, returning +// an error if none exists. func ExperimentsByTrialID(ctx context.Context, trialIDs []int) ([]*model.Experiment, error) { var experiment []*model.Experiment From 228aeae13c040735f7591f04f2782c4c595a48d1 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 15 Aug 2024 15:50:56 -0400 Subject: [PATCH 21/28] lint --- master/internal/db/postgres_experiments.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/master/internal/db/postgres_experiments.go b/master/internal/db/postgres_experiments.go index 2ef04cfe6d1..0c085969b03 100644 --- a/master/internal/db/postgres_experiments.go +++ b/master/internal/db/postgres_experiments.go @@ -521,7 +521,7 @@ WHERE t.id = ?`, trialID).Scan(ctx, &experiment); err != nil { return &experiment, nil } -// ExperimenstByTrialID looks up an experiment by a given list of trialIDs, returning +// ExperimentsByTrialID looks up an experiment by a given list of trialIDs, returning // an error if none exists. func ExperimentsByTrialID(ctx context.Context, trialIDs []int) ([]*model.Experiment, error) { var experiment []*model.Experiment From 2c47b868c3af3a3e480380ab8fe6e3a8cdd1fe8a Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Mon, 19 Aug 2024 10:15:38 -0400 Subject: [PATCH 22/28] fix auth test --- master/internal/trials/utils.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/master/internal/trials/utils.go b/master/internal/trials/utils.go index aa86fd95466..fbf22f521a2 100644 --- a/master/internal/trials/utils.go +++ b/master/internal/trials/utils.go @@ -2,7 +2,9 @@ package trials import ( "context" + "encoding/json" "strconv" + "strings" "github.com/pkg/errors" "google.golang.org/grpc/codes" @@ -53,7 +55,11 @@ func CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx context.Context, return err } - trialNotFound := api.NotFoundErrs("trial", "(multi-trial search)", true) + idString, err := json.Marshal(trialIDs) + if err != nil { + return err + } + trialNotFound := api.NotFoundErrs("trial", strings.Trim(string(idString), "[]"), true) exps, err := db.ExperimentsByTrialID(ctx, trialIDs) if errors.Is(err, db.ErrNotFound) { return trialNotFound From e7588ff1504e3d7c677009e5401b4ff685f9886c Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Mon, 19 Aug 2024 11:09:39 -0400 Subject: [PATCH 23/28] fix auth test pt2 --- master/internal/api_trials_intg_test.go | 12 ++++++------ master/internal/db/postgres_experiments.go | 3 +++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/master/internal/api_trials_intg_test.go b/master/internal/api_trials_intg_test.go index e500d319efc..7bd9e2ad8aa 100644 --- a/master/internal/api_trials_intg_test.go +++ b/master/internal/api_trials_intg_test.go @@ -712,12 +712,12 @@ func TestTrialAuthZ(t *testing.T) { }) return err }, true}, - // {"CanGetExperimentArtifacts", func(id int) error { - // _, err := api.CompareTrials(ctx, &apiv1.CompareTrialsRequest{ - // TrialIds: []int32{int32(id)}, - // }) - // return err - // }, false}, + {"CanGetExperimentArtifacts", func(id int) error { + _, err := api.CompareTrials(ctx, &apiv1.CompareTrialsRequest{ + TrialIds: []int32{int32(id)}, + }) + return err + }, false}, {"CanGetExperimentArtifacts", func(id int) error { _, err := api.GetTrialWorkloads(ctx, &apiv1.GetTrialWorkloadsRequest{ TrialId: int32(id), diff --git a/master/internal/db/postgres_experiments.go b/master/internal/db/postgres_experiments.go index 0c085969b03..477c40cfb1a 100644 --- a/master/internal/db/postgres_experiments.go +++ b/master/internal/db/postgres_experiments.go @@ -536,6 +536,9 @@ WHERE t.id IN (?)`, bun.In(trialIDs)).Scan(ctx, &experiment); err != nil { return nil, MatchSentinelError(err) } + if len(experiment) == 0 { + return nil, ErrNotFound + } return experiment, nil } From 44fe59027d708fa722abdc33ff3900a90d1276f6 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Mon, 19 Aug 2024 11:36:01 -0400 Subject: [PATCH 24/28] add get experiments test --- master/internal/db/postgres_experiments.go | 2 +- .../db/postgres_experiments_intg_test.go | 30 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/master/internal/db/postgres_experiments.go b/master/internal/db/postgres_experiments.go index 477c40cfb1a..697adb8f076 100644 --- a/master/internal/db/postgres_experiments.go +++ b/master/internal/db/postgres_experiments.go @@ -527,7 +527,7 @@ func ExperimentsByTrialID(ctx context.Context, trialIDs []int) ([]*model.Experim var experiment []*model.Experiment if err := Bun().NewRaw(` -SELECT e.id, e.state, e.config, e.start_time, e.end_time, e.archived, +SELECT DISTINCT e.id, e.state, e.config, e.start_time, e.end_time, e.archived, e.owner_id, e.notes, e.job_id, u.username as username, e.project_id, unmanaged, external_experiment_id FROM experiments e JOIN trials t ON e.id = t.experiment_id diff --git a/master/internal/db/postgres_experiments_intg_test.go b/master/internal/db/postgres_experiments_intg_test.go index e45c9fb7df6..48c573bb87b 100644 --- a/master/internal/db/postgres_experiments_intg_test.go +++ b/master/internal/db/postgres_experiments_intg_test.go @@ -270,6 +270,36 @@ func TestExperimentByIDs(t *testing.T) { } } +func TestExperimentsByTrialID(t *testing.T) { + ctx := context.Background() + + require.NoError(t, etc.SetRootPath(RootFromDB)) + db, closeDB := MustResolveTestPostgres(t) + defer closeDB() + MustMigrateTestPostgres(t, db, MigrationsFromDB) + user := RequireMockUser(t, db) + + externalID := uuid.New().String() + exp1 := RequireMockExperimentParams(t, db, user, MockExperimentParams{ + ExternalExperimentID: &externalID, + }, DefaultProjectID) + trial1, _ := RequireMockTrial(t, db, exp1) + trial2, _ := RequireMockTrial(t, db, exp1) + + externalID = uuid.New().String() + exp2 := RequireMockExperimentParams(t, db, user, MockExperimentParams{ + ExternalExperimentID: &externalID, + }, DefaultProjectID) + trial3, _ := RequireMockTrial(t, db, exp2) + + actual, err := ExperimentsByTrialID(ctx, []int{trial1.ID, trial2.ID, trial3.ID}) + require.NoError(t, err) + require.ElementsMatch(t, []int{exp1.ID, exp2.ID}, []int{actual[0].ID, actual[1].ID}) + + _, err = ExperimentsByTrialID(ctx, []int{-999}) + require.ErrorIs(t, err, ErrNotFound) +} + func TestTerminateExperimentInRestart(t *testing.T) { ctx := context.Background() From 8d2585b68eef18c9c701be92ea547486427f3d8a Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 22 Aug 2024 12:31:37 -0400 Subject: [PATCH 25/28] input user to function --- master/internal/api_runs.go | 18 ++- master/internal/api_tensorboard.go | 6 +- master/internal/api_trials.go | 128 +++++++++++++++--- .../internal/trials/api_trial_source_info.go | 13 +- master/internal/trials/utils.go | 15 +- master/internal/trials/utils_intg_test.go | 51 +++++++ 6 files changed, 190 insertions(+), 41 deletions(-) create mode 100644 master/internal/trials/utils_intg_test.go diff --git a/master/internal/api_runs.go b/master/internal/api_runs.go index ed2d895f1cf..af0473b6ead 100644 --- a/master/internal/api_runs.go +++ b/master/internal/api_runs.go @@ -44,8 +44,12 @@ type runCandidateResult struct { func (a *apiServer) RunPrepareForReporting( ctx context.Context, req *apiv1.RunPrepareForReportingRequest, ) (*apiv1.RunPrepareForReportingResponse, error) { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } // TODO(runs) run specific RBAC. - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.RunId), + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.RunId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -1015,8 +1019,12 @@ func pauseResumeAction(ctx context.Context, isPause bool, projectID int32, func (a *apiServer) GetRunMetadata( ctx context.Context, req *apiv1.GetRunMetadataRequest, ) (*apiv1.GetRunMetadataResponse, error) { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } // TODO(runs) run specific RBAC. - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.RunId), + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.RunId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -1033,8 +1041,12 @@ func (a *apiServer) GetRunMetadata( func (a *apiServer) PostRunMetadata( ctx context.Context, req *apiv1.PostRunMetadataRequest, ) (*apiv1.PostRunMetadataResponse, error) { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } // TODO(runs) run specific RBAC. - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.RunId), + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.RunId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } diff --git a/master/internal/api_tensorboard.go b/master/internal/api_tensorboard.go index 65eae44fbc6..f0a8856666e 100644 --- a/master/internal/api_tensorboard.go +++ b/master/internal/api_tensorboard.go @@ -504,8 +504,12 @@ func (a *apiServer) getTensorBoardConfigsFromReq( confByID[expID] = &tensorboardConfig{ExperimentID: expID, Config: conf} } + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } for _, trialID := range req.TrialIds { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), curUser, exputil.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } diff --git a/master/internal/api_trials.go b/master/internal/api_trials.go index e2ab9ebd787..e4800b856e2 100644 --- a/master/internal/api_trials.go +++ b/master/internal/api_trials.go @@ -134,7 +134,11 @@ func (a *apiServer) enrichTrialState(trials ...*trialv1.Trial) error { func (a *apiServer) TrialLogs( req *apiv1.TrialLogsRequest, resp apiv1.Determined_TrialLogsServer, ) error { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), int(req.TrialId), + curUser, _, err := grpcutil.GetUser(resp.Context()) + if err != nil { + return err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return err } @@ -253,8 +257,12 @@ func (a *apiServer) legacyTrialLogs( var followState interface{} trialLogsTimeSinceLastAuth := time.Now() // time.Now() to avoid recheck from a.TrialLogs. fetch := func(r api.BatchRequest) (api.Batch, error) { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } if time.Since(trialLogsTimeSinceLastAuth) >= recheckAuthPeriod { - if err = trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + if err = trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -348,7 +356,11 @@ func constructTrialLogsFilters(req *apiv1.TrialLogsRequest) ([]api.Filter, error func (a *apiServer) TrialLogsFields( req *apiv1.TrialLogsFieldsRequest, resp apiv1.Determined_TrialLogsFieldsServer, ) error { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), int(req.TrialId), + curUser, _, err := grpcutil.GetUser(resp.Context()) + if err != nil { + return err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return err } @@ -368,8 +380,12 @@ func (a *apiServer) TrialLogsFields( api.BatchRequest{Follow: req.Follow}, func(lr api.BatchRequest) (api.Batch, error) { if time.Since(trialLogsTimeSinceLastAuth) >= recheckAuthPeriod { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), - int(req.TrialId), + int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -393,8 +409,12 @@ func (a *apiServer) TrialLogsFields( api.BatchRequest{Follow: req.Follow && i == len(trialTaskIDs)-1}, func(lr api.BatchRequest) (api.Batch, error) { if time.Since(taskLogsTimeSinceLastAuth) >= recheckAuthPeriod { + curUser, _, err := grpcutil.GetUser(resp.Context()) + if err != nil { + return nil, err + } if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), - int(req.TrialId), + int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -452,7 +472,11 @@ func (a *apiServer) TrialLogsFields( func (a *apiServer) GetTrialCheckpoints( ctx context.Context, req *apiv1.GetTrialCheckpointsRequest, ) (*apiv1.GetTrialCheckpointsResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.Id), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.Id), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -526,7 +550,11 @@ func (a *apiServer) GetTrialCheckpoints( func (a *apiServer) KillTrial( ctx context.Context, req *apiv1.KillTrialRequest, ) (*apiv1.KillTrialResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.Id), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.Id), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -683,7 +711,11 @@ WHERE r.run_id = ? func (a *apiServer) GetTrial(ctx context.Context, req *apiv1.GetTrialRequest) ( *apiv1.GetTrialResponse, error, ) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -738,13 +770,17 @@ WHERE t.external_trial_id = ? AND e.external_experiment_id = ?`, func (a *apiServer) PutTrialRetainLogs( ctx context.Context, req *apiv1.PutTrialRetainLogsRequest, ) (*apiv1.PutTrialRetainLogsResponse, error) { + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } if err := trials.CanGetTrialsExperimentAndCheckCanDoAction( - ctx, int(req.TrialId), experiment.AuthZProvider.Get().CanEditExperiment, + ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanEditExperiment, ); err != nil { return nil, err } - err := db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + err = db.Bun().RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { if _, err := tx.NewUpdate().Table("runs"). Set("log_retention_days = ?", req.NumDays). Where("id = ?", req.TrialId). @@ -908,7 +944,11 @@ func (a *apiServer) CompareTrials(ctx context.Context, trialIds = append(trialIds, strconv.Itoa(int(trialID))) trialIntList = append(trialIntList, int(trialID)) } - if err := trials.CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, trialIntList, + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, trialIntList, curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, errors.Wrapf(err, "failed validate permissions") } @@ -998,8 +1038,12 @@ func (a *apiServer) streamMetrics(ctx context.Context, } slices.Sort(trialIDs) + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return err + } for _, trialID := range trialIDs { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(trialID), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return err } @@ -1041,7 +1085,11 @@ func (a *apiServer) streamMetrics(ctx context.Context, func (a *apiServer) GetTrialWorkloads(ctx context.Context, req *apiv1.GetTrialWorkloadsRequest) ( *apiv1.GetTrialWorkloadsResponse, error, ) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -1092,8 +1140,12 @@ func (a *apiServer) GetTrialProfilerMetrics( var timeSinceLastAuth time.Time fetch := func(lr api.BatchRequest) (api.Batch, error) { if time.Since(timeSinceLastAuth) >= recheckAuthPeriod { + curUser, _, err := grpcutil.GetUser(resp.Context()) + if err != nil { + return nil, err + } if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), - int(req.Labels.TrialId), + int(req.Labels.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -1138,8 +1190,12 @@ func (a *apiServer) GetTrialProfilerAvailableSeries( var timeSinceLastAuth time.Time fetch := func(_ api.BatchRequest) (api.Batch, error) { if time.Since(timeSinceLastAuth) >= recheckAuthPeriod { + curUser, _, err := grpcutil.GetUser(resp.Context()) + if err != nil { + return nil, err + } if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(resp.Context(), - int(req.TrialId), + int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -1184,7 +1240,11 @@ func (a *apiServer) PostTrialProfilerMetricsBatch( for _, batch := range req.Batches { trialID := int(batch.Labels.TrialId) if !existingTrials[trialID] { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, trialID, + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, trialID, curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -1311,7 +1371,11 @@ func (a *apiServer) MarkAllocationResourcesDaemon( func (a *apiServer) GetCurrentTrialSearcherOperation( ctx context.Context, req *apiv1.GetCurrentTrialSearcherOperationRequest, ) (*apiv1.GetCurrentTrialSearcherOperationResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { return nil, err } @@ -1342,7 +1406,11 @@ func (a *apiServer) GetCurrentTrialSearcherOperation( func (a *apiServer) CompleteTrialSearcherValidation( ctx context.Context, req *apiv1.CompleteTrialSearcherValidationRequest, ) (*apiv1.CompleteTrialSearcherValidationResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -1370,7 +1438,11 @@ func (a *apiServer) CompleteTrialSearcherValidation( func (a *apiServer) ReportTrialSearcherEarlyExit( ctx context.Context, req *apiv1.ReportTrialSearcherEarlyExitRequest, ) (*apiv1.ReportTrialSearcherEarlyExitResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -1397,7 +1469,11 @@ func (a *apiServer) ReportTrialSearcherEarlyExit( func (a *apiServer) ReportTrialProgress( ctx context.Context, req *apiv1.ReportTrialProgressRequest, ) (*apiv1.ReportTrialProgressResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -1428,7 +1504,11 @@ func (a *apiServer) ReportTrialMetrics( if err := metricGroup.Validate(); err != nil { return nil, err } - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.Metrics.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.Metrics.TrialId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -1559,7 +1639,11 @@ func (a *apiServer) AllocationRendezvousInfo( func (a *apiServer) PostTrialRunnerMetadata( ctx context.Context, req *apiv1.PostTrialRunnerMetadataRequest, ) (*apiv1.PostTrialRunnerMetadataResponse, error) { - if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := trials.CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(req.TrialId), curUser, experiment.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } diff --git a/master/internal/trials/api_trial_source_info.go b/master/internal/trials/api_trial_source_info.go index 92122ec09d7..a29c8cc9bd6 100644 --- a/master/internal/trials/api_trial_source_info.go +++ b/master/internal/trials/api_trial_source_info.go @@ -8,6 +8,7 @@ import ( "github.com/determined-ai/determined/master/internal/db" expauth "github.com/determined-ai/determined/master/internal/experiment" + "github.com/determined-ai/determined/master/internal/grpcutil" "github.com/determined-ai/determined/proto/pkg/apiv1" "github.com/determined-ai/determined/proto/pkg/trialv1" ) @@ -22,7 +23,11 @@ func (a *TrialSourceInfoAPIServer) ReportTrialSourceInfo( ctx context.Context, req *apiv1.ReportTrialSourceInfoRequest, ) (*apiv1.ReportTrialSourceInfoResponse, error) { tsi := req.TrialSourceInfo - if err := CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(tsi.TrialId), + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } + if err := CanGetTrialsExperimentAndCheckCanDoAction(ctx, int(tsi.TrialId), curUser, expauth.AuthZProvider.Get().CanEditExperiment); err != nil { return nil, err } @@ -46,13 +51,17 @@ func GetMetricsForTrialSourceInfoQuery( return nil, fmt.Errorf("failed to get trial source info %w", err) } + curUser, _, err := grpcutil.GetUser(ctx) + if err != nil { + return nil, err + } // TODO (Taylor): If we reach a point where this becomes a performance bottleneck // we should join on trial_source_infos -> trials -> experiments to get the // workspace_id and get permissions on those without checking each trial individually ret := []*trialv1.MetricsReport{} numMetricsLimit := 1000 for _, val := range trialIds { - if err := CanGetTrialsExperimentAndCheckCanDoAction(ctx, val.TrialID, + if err := CanGetTrialsExperimentAndCheckCanDoAction(ctx, val.TrialID, curUser, expauth.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil { // If the user can see the checkpoint, but not one of the inference // or fine tuning trials that points to it, simply don't show those diff --git a/master/internal/trials/utils.go b/master/internal/trials/utils.go index fbf22f521a2..525077afb71 100644 --- a/master/internal/trials/utils.go +++ b/master/internal/trials/utils.go @@ -14,20 +14,14 @@ import ( "github.com/determined-ai/determined/master/internal/authz" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/internal/experiment" - "github.com/determined-ai/determined/master/internal/grpcutil" "github.com/determined-ai/determined/master/pkg/model" ) // CanGetTrialsExperimentAndCheckCanDoAction is a utility function for generalizing // RBAC support for trials and experiments. func CanGetTrialsExperimentAndCheckCanDoAction(ctx context.Context, - trialID int, actionFunc func(context.Context, model.User, *model.Experiment) error, + trialID int, curUser *model.User, actionFunc func(context.Context, model.User, *model.Experiment) error, ) error { - curUser, _, err := grpcutil.GetUser(ctx) - if err != nil { - return err - } - trialNotFound := api.NotFoundErrs("trial", strconv.Itoa(trialID), true) exp, err := db.ExperimentByTrialID(ctx, trialID) if errors.Is(err, db.ErrNotFound) { @@ -48,13 +42,8 @@ func CanGetTrialsExperimentAndCheckCanDoAction(ctx context.Context, // CanGetTrialsExperimentAndCheckCanDoActionBulk functions the same as // CanGetTrialsExperimentAndCheckCanDoAction but takes in multiple trial ids. func CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx context.Context, - trialIDs []int, actionFunc func(context.Context, model.User, *model.Experiment) error, + trialIDs []int, curUser *model.User, actionFunc func(context.Context, model.User, *model.Experiment) error, ) error { - curUser, _, err := grpcutil.GetUser(ctx) - if err != nil { - return err - } - idString, err := json.Marshal(trialIDs) if err != nil { return err diff --git a/master/internal/trials/utils_intg_test.go b/master/internal/trials/utils_intg_test.go new file mode 100644 index 00000000000..4752b8bebbc --- /dev/null +++ b/master/internal/trials/utils_intg_test.go @@ -0,0 +1,51 @@ +//go:build integration +// +build integration + +package trials + +import ( + "context" + "testing" + + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/etc" + "github.com/determined-ai/determined/master/pkg/model" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func actionFuncAllow(context.Context, model.User, *model.Experiment) error { + return nil +} + +func actionFuncDeny(context.Context, model.User, *model.Experiment) error { + return status.Error(codes.PermissionDenied, "") +} + +func TestCanGetTrialsExperimentAndCheckCanDoAction(t *testing.T) { + ctx := context.Background() + + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, closeDB := db.MustResolveTestPostgres(t) + defer closeDB() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + user := db.RequireMockUser(t, pgDB) + + externalID := uuid.New().String() + exp := db.RequireMockExperimentParams(t, pgDB, user, db.MockExperimentParams{ + ExternalExperimentID: &externalID, + }, db.DefaultProjectID) + trial, _ := db.RequireMockTrial(t, pgDB, exp) + + // allowed + err := CanGetTrialsExperimentAndCheckCanDoAction(ctx, trial.ID, &user, actionFuncAllow) + require.NoError(t, err) + // denied + err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, trial.ID, &user, actionFuncDeny) + require.Error(t, err) + // not found + err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, 999, &user, actionFuncAllow) + require.Error(t, err) +} From 2d1655f4b9052f8f6812c5c80365b89d2796bc4f Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 22 Aug 2024 12:46:12 -0400 Subject: [PATCH 26/28] update tests --- master/internal/trials/utils_intg_test.go | 35 ++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/master/internal/trials/utils_intg_test.go b/master/internal/trials/utils_intg_test.go index 4752b8bebbc..ebd832e4038 100644 --- a/master/internal/trials/utils_intg_test.go +++ b/master/internal/trials/utils_intg_test.go @@ -7,6 +7,7 @@ import ( "context" "testing" + apiPkg "github.com/determined-ai/determined/master/internal/api" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/pkg/etc" "github.com/determined-ai/determined/master/pkg/model" @@ -16,6 +17,8 @@ import ( "google.golang.org/grpc/status" ) +const testTrialCount = 5 + func actionFuncAllow(context.Context, model.User, *model.Experiment) error { return nil } @@ -46,6 +49,36 @@ func TestCanGetTrialsExperimentAndCheckCanDoAction(t *testing.T) { err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, trial.ID, &user, actionFuncDeny) require.Error(t, err) // not found - err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, 999, &user, actionFuncAllow) + err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, -999, &user, actionFuncAllow) + require.ErrorIs(t, err, apiPkg.NotFoundErrs("trial", "-999", true)) +} + +func TestCanGetTrialsExperimentAndCheckCanDoActionBulk(t *testing.T) { + ctx := context.Background() + + require.NoError(t, etc.SetRootPath(db.RootFromDB)) + pgDB, closeDB := db.MustResolveTestPostgres(t) + defer closeDB() + db.MustMigrateTestPostgres(t, pgDB, db.MigrationsFromDB) + user := db.RequireMockUser(t, pgDB) + + externalID := uuid.New().String() + exp := db.RequireMockExperimentParams(t, pgDB, user, db.MockExperimentParams{ + ExternalExperimentID: &externalID, + }, db.DefaultProjectID) + trialIDs := []int{} + for i := 0; i < testTrialCount; i++ { + trial, _ := db.RequireMockTrial(t, pgDB, exp) + trialIDs = append(trialIDs, trial.ID) + } + + // allowed + err := CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, trialIDs, &user, actionFuncAllow) + require.NoError(t, err) + // denied + err = CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, trialIDs, &user, actionFuncDeny) require.Error(t, err) + // not found + err = CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, []int{-1, -2, -3}, &user, actionFuncAllow) + require.ErrorIs(t, err, apiPkg.NotFoundErrs("trial", "-1,-2,-3", true)) } From 09d87e53e9432bc4ead1a25b0fe15da2645f0f2e Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 22 Aug 2024 12:57:24 -0400 Subject: [PATCH 27/28] goimport --- master/internal/trials/utils_intg_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/master/internal/trials/utils_intg_test.go b/master/internal/trials/utils_intg_test.go index ebd832e4038..fcd4656a9ce 100644 --- a/master/internal/trials/utils_intg_test.go +++ b/master/internal/trials/utils_intg_test.go @@ -7,14 +7,15 @@ import ( "context" "testing" - apiPkg "github.com/determined-ai/determined/master/internal/api" - "github.com/determined-ai/determined/master/internal/db" - "github.com/determined-ai/determined/master/pkg/etc" - "github.com/determined-ai/determined/master/pkg/model" "github.com/google/uuid" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + apiPkg "github.com/determined-ai/determined/master/internal/api" + "github.com/determined-ai/determined/master/internal/db" + "github.com/determined-ai/determined/master/pkg/etc" + "github.com/determined-ai/determined/master/pkg/model" ) const testTrialCount = 5 From d452d0aa86082c2af628ae283a57fca1bf1a72c2 Mon Sep 17 00:00:00 2001 From: AmanuelAaron Date: Thu, 22 Aug 2024 13:35:57 -0400 Subject: [PATCH 28/28] validate permission denied error --- master/internal/trials/utils_intg_test.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/master/internal/trials/utils_intg_test.go b/master/internal/trials/utils_intg_test.go index fcd4656a9ce..e5c80ee3dc1 100644 --- a/master/internal/trials/utils_intg_test.go +++ b/master/internal/trials/utils_intg_test.go @@ -25,7 +25,7 @@ func actionFuncAllow(context.Context, model.User, *model.Experiment) error { } func actionFuncDeny(context.Context, model.User, *model.Experiment) error { - return status.Error(codes.PermissionDenied, "") + return status.Error(codes.Unknown, "") } func TestCanGetTrialsExperimentAndCheckCanDoAction(t *testing.T) { @@ -48,7 +48,8 @@ func TestCanGetTrialsExperimentAndCheckCanDoAction(t *testing.T) { require.NoError(t, err) // denied err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, trial.ID, &user, actionFuncDeny) - require.Error(t, err) + expectedError := status.Error(codes.Unknown, "") + require.ErrorIs(t, err, status.Error(codes.PermissionDenied, expectedError.Error())) // not found err = CanGetTrialsExperimentAndCheckCanDoAction(ctx, -999, &user, actionFuncAllow) require.ErrorIs(t, err, apiPkg.NotFoundErrs("trial", "-999", true)) @@ -78,7 +79,8 @@ func TestCanGetTrialsExperimentAndCheckCanDoActionBulk(t *testing.T) { require.NoError(t, err) // denied err = CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, trialIDs, &user, actionFuncDeny) - require.Error(t, err) + expectedError := status.Error(codes.Unknown, "") + require.ErrorIs(t, err, status.Error(codes.PermissionDenied, expectedError.Error())) // not found err = CanGetTrialsExperimentAndCheckCanDoActionBulk(ctx, []int{-1, -2, -3}, &user, actionFuncAllow) require.ErrorIs(t, err, apiPkg.NotFoundErrs("trial", "-1,-2,-3", true))