diff --git a/tests/test_experiments.py b/tests/test_experiments.py index 217bdd35..38e0962d 100644 --- a/tests/test_experiments.py +++ b/tests/test_experiments.py @@ -3,7 +3,6 @@ import pytest from pytz import UTC -import inspect from faculty.clients.experiment import ( Experiment, @@ -11,6 +10,7 @@ ExperimentRunStatus, ListExperimentRunsResponse, Metric, + Page, Pagination, Param, SingleFilter, @@ -55,6 +55,12 @@ deleted_at=DELETED_AT, ) +FILTER = SingleFilter( + SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2" +) + +SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] + EXPERIMENT_RUN = ExperimentRun( id=EXPERIMENT_RUN_ID, run_number=EXPERIMENT_RUN_NUMBER, @@ -70,29 +76,53 @@ params=[PARAM], metrics=[METRIC], ) - -PAGINATION = Pagination(start=20, size=10, previous=None, next=None) - +PAGINATION = Pagination(0, 1, None, None) LIST_EXPERIMENT_RUNS_RESPONSE = ListExperimentRunsResponse( runs=[EXPERIMENT_RUN], pagination=PAGINATION ) - -FILTER = SingleFilter( - SingleFilterBy.EXPERIMENT_ID, None, SingleFilterOperator.EQUAL_TO, "2" -) - -SORT = [Sort(SortBy.METRIC, "metric_key", SortOrder.ASC)] - - -def test_experiment_run_query(mocker): - - experiment_client_mock = mocker.MagicMock() - experiment_client_mock.query_runs.return_value = ( - LIST_EXPERIMENT_RUNS_RESPONSE +EXPECTED_RUNS = [ + FacultyExperimentRun( + id=EXPERIMENT_RUN_ID, + run_number=EXPERIMENT_RUN_NUMBER, + name=EXPERIMENT_RUN_NAME, + parent_run_id=PARENT_RUN_ID, + experiment_id=EXPERIMENT.id, + artifact_location="faculty:", + status=ExperimentRunStatus.RUNNING, + started_at=RUN_STARTED_AT, + ended_at=RUN_ENDED_AT, + deleted_at=DELETED_AT, + tags=[TAG], + params=[PARAM], + metrics=[METRIC], ) - mocker.patch("faculty.client", return_value=experiment_client_mock) +] - expected_response = FacultyExperimentRun( +PAGINATION_MULTIPLE_1 = Pagination(0, 1, None, Page(1, 1)) +LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_1 = ListExperimentRunsResponse( + runs=[EXPERIMENT_RUN], pagination=PAGINATION_MULTIPLE_1 +) +EXPERIMENT_RUN_MULTIPLE_2 = ExperimentRun( + id=7, + run_number=EXPERIMENT_RUN_NUMBER, + name=EXPERIMENT_RUN_NAME, + parent_run_id=PARENT_RUN_ID, + experiment_id=EXPERIMENT.id, + artifact_location="faculty:", + status=ExperimentRunStatus.RUNNING, + started_at=RUN_STARTED_AT, + ended_at=RUN_ENDED_AT, + deleted_at=DELETED_AT, + tags=[TAG], + params=[PARAM], + metrics=[METRIC], +) +PAGINATION_MULTIPLE_2 = Pagination(1, 1, Page(0, 1), None) +LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_2 = ListExperimentRunsResponse( + runs=[EXPERIMENT_RUN_MULTIPLE_2], pagination=PAGINATION_MULTIPLE_2 +) +EXPECTED_RUNS_2 = [ + FacultyExperimentRun( id=EXPERIMENT_RUN_ID, run_number=EXPERIMENT_RUN_NUMBER, name=EXPERIMENT_RUN_NAME, @@ -106,27 +136,65 @@ def test_experiment_run_query(mocker): tags=[TAG], params=[PARAM], metrics=[METRIC], + ), + FacultyExperimentRun( + id=7, + run_number=EXPERIMENT_RUN_NUMBER, + name=EXPERIMENT_RUN_NAME, + parent_run_id=PARENT_RUN_ID, + experiment_id=EXPERIMENT.id, + artifact_location="faculty:", + status=ExperimentRunStatus.RUNNING, + started_at=RUN_STARTED_AT, + ended_at=RUN_ENDED_AT, + deleted_at=DELETED_AT, + tags=[TAG], + params=[PARAM], + metrics=[METRIC], + ), +] + + +@pytest.mark.parametrize( + "query_runs_side_effects,expected_runs", + [ + [[LIST_EXPERIMENT_RUNS_RESPONSE], EXPECTED_RUNS], + [ + [ + LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_1, + LIST_EXPERIMENT_RUNS_RESPONSE_MULTIPLE_2, + ], + EXPECTED_RUNS_2, + ], + ], +) +def test_experiment_run_query_single_call( + mocker, query_runs_side_effects, expected_runs +): + experiment_client_mock = mocker.MagicMock() + experiment_client_mock.query_runs = mocker.MagicMock( + side_effect=query_runs_side_effects ) + mocker.patch("faculty.client", return_value=experiment_client_mock) response = FacultyExperimentRun.query(PROJECT_ID, FILTER, SORT) assert isinstance(response, ExperimentRunQueryResult) - returned_run = list(response)[0] - assert isinstance(returned_run, FacultyExperimentRun) - assert all( + returned_runs = list(response) + for expected_run, returned_run in zip(expected_runs, returned_runs): + assert isinstance(returned_run, FacultyExperimentRun) + assert _are_runs_equal(expected_run, returned_run) + + +def _are_runs_equal(this, that): + return all( list( i == j for i, j in ( list( zip( - [ - getattr(returned_run, attr) - for attr in returned_run.__dict__.keys() - ], - [ - getattr(expected_response, attr) - for attr in expected_response.__dict__.keys() - ], + [getattr(this, attr) for attr in this.__dict__.keys()], + [getattr(that, attr) for attr in that.__dict__.keys()], ) ) )