diff --git a/.travis.yml b/.travis.yml index 1c724c43..3d8ba6c3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,8 +18,6 @@ matrix: env: TOXENV=license - python: 2.7 env: TOXENV=py27 - - python: 3.4 - env: TOXENV=py34 - python: 3.5 env: TOXENV=py35 - python: 3.6 diff --git a/faculty/_util/__init__.py b/faculty/_util/__init__.py new file mode 100644 index 00000000..f5221fcb --- /dev/null +++ b/faculty/_util/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/faculty/_util/resolvers.py b/faculty/_util/resolvers.py new file mode 100644 index 00000000..92bf06bf --- /dev/null +++ b/faculty/_util/resolvers.py @@ -0,0 +1,99 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from uuid import UUID + +from cachetools.func import lru_cache + +from faculty.context import get_context +from faculty.clients import AccountClient, ProjectClient + + +def _make_uuid(value): + """Make a UUID from the passed value. + + Pass through UUID objects as the UUID constructor fails when passed UUID + objects. + """ + if isinstance(value, UUID): + return value + else: + return UUID(value) + + +def _project_from_name(session, name): + """Provided a project name, find a matching project ID. + + This method searches all the projects accessible to the active user for a + matching project. If not exactly one project matches, a ValueError is + raised. + """ + + user_id = AccountClient(session).authenticated_user_id() + projects = ProjectClient(session).list_accessible_by_user(user_id) + + matches = [project for project in projects if project.name == name] + if len(matches) == 1: + return matches[0] + elif len(matches) == 0: + raise ValueError("No projects of name {} found".format(name)) + else: + raise ValueError("Multiple projects of name {} found".format(name)) + + +@lru_cache() +def resolve_project_id(session, project=None): + """Resolve the ID of a project based on ID, name or the current context. + + This helper encapsulates logic for determining a project in three + situations: + + * If ``None`` is passed as the project, or if no project is passed, the + project will be inferred from the runtime context (i.e. environment + variables), and so will correspond to the 'current project' when run + inside Faculty platform. + * If a ``uuid.UUID`` or a string containing a valid UUID is passed, this + will be assumed to be the ID of the project and will be returned. + * If any other string is passed, the Faculty platform will be queried for + projects matching that name. If exactly one of that name is accessible to + the user, its ID will be returned, otherwise a ``ValueError`` will be + raised. + + Parameters + ---------- + session : faculty.session.Session + project : str, uuid.UUID or None + Information to use to determine the active project. + + Returns + ------- + uuid.UUID + The ID of the project + """ + + if project is None: + context = get_context() + if context.project_id is None: + raise ValueError( + "Must pass a project name or ID when none can be determined " + "from the runtime context" + ) + else: + return context.project_id + else: + try: + return _make_uuid(project) + except ValueError: + return _project_from_name(session, project).id diff --git a/faculty/clients/experiment/__init__.py b/faculty/clients/experiment/__init__.py index 7f6a623f..5022eb03 100644 --- a/faculty/clients/experiment/__init__.py +++ b/faculty/clients/experiment/__init__.py @@ -38,6 +38,7 @@ RestoreExperimentRunsResponse, RunIdFilter, RunNumberSort, + SortOrder, StartedAtSort, Tag, TagFilter, diff --git a/faculty/clients/experiment/_models.py b/faculty/clients/experiment/_models.py index 00b6fecc..cf05a620 100644 --- a/faculty/clients/experiment/_models.py +++ b/faculty/clients/experiment/_models.py @@ -16,6 +16,8 @@ from collections import namedtuple from enum import Enum +from attr import attrs, attrib + class LifecycleStage(Enum): ACTIVE = "active" @@ -81,13 +83,73 @@ class ComparisonOperator(Enum): GREATER_THAN_OR_EQUAL_TO = "ge" -ProjectIdFilter = namedtuple("ProjectIdFilter", ["operator", "value"]) -ExperimentIdFilter = namedtuple("ExperimentIdFilter", ["operator", "value"]) -RunIdFilter = namedtuple("RunIdFilter", ["operator", "value"]) -DeletedAtFilter = namedtuple("DeletedAtFilter", ["operator", "value"]) -TagFilter = namedtuple("TagFilter", ["key", "operator", "value"]) -ParamFilter = namedtuple("ParamFilter", ["key", "operator", "value"]) -MetricFilter = namedtuple("MetricFilter", ["key", "operator", "value"]) +def _matching_compound(filter, operator): + return isinstance(filter, CompoundFilter) and filter.operator == operator + + +def _combine_filters(first, second, op): + if _matching_compound(first, op) and _matching_compound(second, op): + conditions = first.conditions + second.conditions + elif _matching_compound(first, op): + conditions = first.conditions + [second] + elif _matching_compound(second, op): + conditions = [first] + second.conditions + else: + conditions = [first, second] + return CompoundFilter(op, conditions) + + +class BaseFilter(object): + def __and__(self, other): + return _combine_filters(self, other, LogicalOperator.AND) + + def __or__(self, other): + return _combine_filters(self, other, LogicalOperator.OR) + + +@attrs +class ProjectIdFilter(BaseFilter): + operator = attrib() + value = attrib() + + +@attrs +class ExperimentIdFilter(BaseFilter): + operator = attrib() + value = attrib() + + +@attrs +class RunIdFilter(BaseFilter): + operator = attrib() + value = attrib() + + +@attrs +class DeletedAtFilter(BaseFilter): + operator = attrib() + value = attrib() + + +@attrs +class TagFilter(BaseFilter): + key = attrib() + operator = attrib() + value = attrib() + + +@attrs +class ParamFilter(BaseFilter): + key = attrib() + operator = attrib() + value = attrib() + + +@attrs +class MetricFilter(BaseFilter): + key = attrib() + operator = attrib() + value = attrib() class LogicalOperator(Enum): @@ -95,7 +157,10 @@ class LogicalOperator(Enum): OR = "or" -CompoundFilter = namedtuple("CompoundFilter", ["operator", "conditions"]) +@attrs +class CompoundFilter(BaseFilter): + operator = attrib() + conditions = attrib() class SortOrder(Enum): diff --git a/faculty/experiment.py b/faculty/experiment.py new file mode 100644 index 00000000..44be2c6d --- /dev/null +++ b/faculty/experiment.py @@ -0,0 +1,316 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from attr import attrs, attrib +import pandas + +from faculty.session import get_session +from faculty._util.resolvers import resolve_project_id +from faculty.clients.experiment import ExperimentClient + +from faculty.clients.experiment import ( + ComparisonOperator, + DeletedAtFilter, + ExperimentIdFilter, + MetricFilter, + ParamFilter, + RunIdFilter, + TagFilter, +) + + +@attrs +class ExperimentRun(object): + """A single run of an experiment.""" + + id = attrib() + run_number = attrib() + experiment_id = attrib() + name = attrib() + parent_run_id = attrib() + artifact_location = attrib() + status = attrib() + started_at = attrib() + ended_at = attrib() + deleted_at = attrib() + tags = attrib() + params = attrib() + metrics = attrib() + + @classmethod + def _from_client_model(cls, client_object): + return cls(**client_object._asdict()) + + @classmethod + def query(cls, project=None, filter=None, sort=None, **session_config): + """Query the platform for experiment runs. + + Parameters + ---------- + project : str, UUID, or None + The name or ID of a project. If ``None`` is passed (the default), + the project will be inferred from the runtime context. + filter : a filter object from ``faculty.clients.experiment`` + Condition(s) to filter experiment runs by. ``FilterBy`` provides a + convenience interface for constructing filter objects. + sort : a sequence of sort objects from ``faculty.clients.experiment`` + Condition(s) to sort experiment runs by. + **session_config + Configuration options to build the session with. + + Returns + ------- + ExperimentRunList + + Examples + -------- + Get all experiment runs in the current project: + + >>> ExperimentRun.query() + ExperimentRunList([ExperimentRun(...)]) + + Get all experiment runs in a named project: + + >>> ExperimentRun.query("my project") + ExperimentRunList([ExperimentRun(...)]) + + Filter experiment runs by experiment ID: + + >>> ExperimentRun.query(filter=FilterBy.experiment_id() == 2) + ExperimentRunList([ExperimentRun(...)]) + + Filter experiment runs by a more complex condition: + + >>> filter = ( + ... FilterBy.experiment_id().one_of([2, 3, 4]) & + ... (FilterBy.metric("accuracy") > 0.9) & + ... (FilterBy.param("alpha") < 0.3) + ... ) + >>> ExperimentRun.query("my project", filter) + ExperimentRunList([ExperimentRun(...)]) + """ + + session = get_session(**session_config) + project_id = resolve_project_id(session, project) + + def _get_runs(): + client = ExperimentClient(session) + + response = client.query_runs(project_id, filter, sort) + for run in response.runs: + yield cls._from_client_model(run) + + while response.pagination.next is not None: + response = client.query_runs( + project_id, + filter, + sort, + start=response.pagination.next.start, + limit=response.pagination.next.limit, + ) + for run in response.runs: + yield cls._from_client_model(run) + + return ExperimentRunList(_get_runs()) + + +class ExperimentRunList(list): + """A list of experiment runs. + + This collection is a subclass of ``list``, and so supports all its + functionality, but adds the ``as_dataframe`` method which returns a + representation of the contained ExperimentRuns as a ``pandas.DataFrame``. + """ + + def __repr__(self): + return "{}({})".format( + self.__class__.__name__, super(ExperimentRunList, self).__repr__() + ) + + def as_dataframe(self): + """Get the experiment runs as a pandas DataFrame. + + Returns + ------- + pandas.DataFrame + """ + + records = [] + for run in self: + row = { + ("experiment_id", ""): run.experiment_id, + ("run_id", ""): run.id, + ("run_number", ""): run.run_number, + ("status", ""): run.status.value, + ("started_at", ""): run.started_at, + ("ended_at", ""): run.ended_at, + } + for param in run.params: + row[("params", param.key)] = param.value + for metric in run.metrics: + row[("metrics", metric.key)] = metric.value + records.append(row) + + df = pandas.DataFrame(records) + df.columns = pandas.MultiIndex.from_tuples(df.columns) + + # Reorder columns and return + column_order = [ + "experiment_id", + "run_id", + "run_number", + "status", + "started_at", + "ended_at", + ] + if "params" in df.columns: + column_order.append("params") + if "metrics" in df.columns: + column_order.append("metrics") + return df[column_order] + + +class _FilterBuilder(object): + def __init__(self, constructor, *constructor_args): + self.constructor = constructor + self.constructor_args = constructor_args + + def _build(self, *args): + return self.constructor(*(self.constructor_args + args)) + + def defined(self, value=True): + return self._build(ComparisonOperator.DEFINED, value) + + def __eq__(self, value): + return self._build(ComparisonOperator.EQUAL_TO, value) + + def __ne__(self, value): + return self._build(ComparisonOperator.NOT_EQUAL_TO, value) + + def __gt__(self, value): + return self._build(ComparisonOperator.GREATER_THAN, value) + + def __ge__(self, value): + return self._build(ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, value) + + def __lt__(self, value): + return self._build(ComparisonOperator.LESS_THAN, value) + + def __le__(self, value): + return self._build(ComparisonOperator.LESS_THAN_OR_EQUAL_TO, value) + + def one_of(self, values): + try: + first, remaining = values[0], values[1:] + except IndexError: + raise ValueError("Must provide at least one value") + filter = self == first + for val in remaining: + filter |= self == val + return filter + + +class FilterBy(object): + @staticmethod + def experiment_id(): + """Filter by experiment ID. + + Examples + -------- + Get runs for experiment 4: + + >>> FilterBy.experiment_id() == 4 + """ + return _FilterBuilder(ExperimentIdFilter) + + @staticmethod + def run_id(): + """Filter by run ID. + + Examples + -------- + Get the run with a specified ID: + + >>> FilterBy.run_id() == "945f1d96-9937-4b95-aa3f-addcdd1c8749" + """ + return _FilterBuilder(RunIdFilter) + + @staticmethod + def deleted_at(): + """Filter by run deletion time. + + Examples + -------- + Get runs deleted more than ten minutes ago: + + >>> from datetime import datetime, timedelta + >>> FilterBy.deleted_at() < datetime.now() - timedelta(minutes=10) + + Get non-deleted runs: + + >>> FilterBy.deleted_at() == None + """ + return _FilterBuilder(DeletedAtFilter) + + @staticmethod + def tag(key): + """Filter by run tag. + + Examples + -------- + Get runs with a particular tag: + + >>> FilterBy.tag("key") == "value" + + Get runs where a tag is set, with any value: + + >>> FilterBy.tag("key") != None + """ + return _FilterBuilder(TagFilter, key) + + @staticmethod + def param(key): + """Filter by parameter. + + Examples + -------- + Get runs with a particular parameter value: + + >>> FilterBy.param("key") == "value" + + Params also support filtering by numeric value: + + >>> FilterBy.param("alpha") > 0.2 + """ + return _FilterBuilder(ParamFilter, key) + + @staticmethod + def metric(key): + """Filter by metric. + + Examples + -------- + Get runs with matching metric values: + + >>> FilterBy.metric("accuracy") > 0.9 + + To filter a range of values, combine them with ``&``: + + >>> ( + ... (FilterBy.metric("accuracy") > 0.8 ) & + ... (FilterBy.metric("accuracy") > 0.9) + ... ) + """ + return _FilterBuilder(MetricFilter, key) diff --git a/setup.py b/setup.py index e7f25fa8..2ff51703 100644 --- a/setup.py +++ b/setup.py @@ -31,9 +31,11 @@ "pytz", "six", "enum34; python_version<'3.4'", + "cachetools", + "attrs", + "pandas", # Install marshmallow with 'reco' (recommended) extras to ensure a # compatible version of python-dateutil is available - "attrs", "marshmallow[reco]==3.0.0rc3", "marshmallow_enum", "marshmallow-oneofschema==2.0.0b2", diff --git a/tests/_util/__init__.py b/tests/_util/__init__.py new file mode 100644 index 00000000..f5221fcb --- /dev/null +++ b/tests/_util/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/_util/test_resolvers.py b/tests/_util/test_resolvers.py new file mode 100644 index 00000000..da82d948 --- /dev/null +++ b/tests/_util/test_resolvers.py @@ -0,0 +1,109 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from uuid import uuid4 + +import pytest + +from faculty._util.resolvers import resolve_project_id + + +PROJECT_ID = uuid4() + + +@pytest.fixture(autouse=True) +def clear_cache(): + resolve_project_id.cache_clear() + + +@pytest.fixture +def mock_session(mocker): + return mocker.Mock() + + +@pytest.fixture +def mock_account_client(mocker, mock_session): + class_mock = mocker.patch("faculty._util.resolvers.AccountClient") + yield class_mock.return_value + class_mock.assert_called_once_with(mock_session) + + +@pytest.fixture +def mock_project_client(mocker, mock_session): + class_mock = mocker.patch("faculty._util.resolvers.ProjectClient") + yield class_mock.return_value + class_mock.assert_called_once_with(mock_session) + + +def test_resolve_project_id( + mocker, mock_session, mock_account_client, mock_project_client +): + project = mocker.Mock() + project.name = "project name" + mock_project_client.list_accessible_by_user.return_value = [ + mocker.Mock(), + project, + mocker.Mock(), + ] + + assert resolve_project_id(mock_session, "project name") == project.id + + mock_account_client.authenticated_user_id.assert_called_once_with() + mock_project_client.list_accessible_by_user.assert_called_once_with( + mock_account_client.authenticated_user_id.return_value + ) + + +def test_resolve_project_id_no_matches( + mocker, mock_session, mock_account_client, mock_project_client +): + mock_project_client.list_accessible_by_user.return_value = [ + mocker.Mock(), + mocker.Mock(), + ] + with pytest.raises(ValueError, match="No projects .* found"): + resolve_project_id(mock_session, "project name") + + +def test_resolve_project_id_multiple_matches( + mocker, mock_session, mock_account_client, mock_project_client +): + project = mocker.Mock() + project.name = "project name" + mock_project_client.list_accessible_by_user.return_value = [ + project, + project, + ] + with pytest.raises(ValueError, match="Multiple projects .* found"): + resolve_project_id(mock_session, "project name") + + +@pytest.mark.parametrize("argument", [PROJECT_ID, str(PROJECT_ID)]) +def test_resolve_project_id_is_uuid(mock_session, argument): + assert resolve_project_id(mock_session, argument) == PROJECT_ID + + +def test_resolve_project_id_from_context(mocker, mock_session): + context = mocker.Mock() + mocker.patch("faculty._util.resolvers.get_context", return_value=context) + assert resolve_project_id(mock_session) == context.project_id + + +def test_resolve_project_id_from_context_missing(mocker, mock_session): + context = mocker.Mock() + context.project_id = None + mocker.patch("faculty._util.resolvers.get_context", return_value=context) + with pytest.raises(ValueError): + resolve_project_id(mock_session) diff --git a/tests/clients/experiment/test_models.py b/tests/clients/experiment/test_models.py new file mode 100644 index 00000000..f40dd74b --- /dev/null +++ b/tests/clients/experiment/test_models.py @@ -0,0 +1,109 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import uuid + +import pytest + +from faculty.clients.experiment._models import ( + ComparisonOperator, + CompoundFilter, + DeletedAtFilter, + ExperimentIdFilter, + LogicalOperator, + MetricFilter, + ParamFilter, + ProjectIdFilter, + RunIdFilter, + TagFilter, +) + + +SINGLE_FILTERS = [ + ProjectIdFilter(ComparisonOperator.EQUAL_TO, uuid.uuid4()), + ExperimentIdFilter(ComparisonOperator.NOT_EQUAL_TO, 4), + RunIdFilter(ComparisonOperator.EQUAL_TO, uuid.uuid4()), + DeletedAtFilter(ComparisonOperator.DEFINED, False), + TagFilter("key", ComparisonOperator.EQUAL_TO, "value"), + ParamFilter("key", ComparisonOperator.NOT_EQUAL_TO, "value"), + ParamFilter("key", ComparisonOperator.GREATER_THAN, 0.3), + MetricFilter("key", ComparisonOperator.LESS_THAN_OR_EQUAL_TO, 0.6), +] +AND_FILTER = CompoundFilter( + LogicalOperator.AND, + [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 4), + ParamFilter("key", ComparisonOperator.GREATER_THAN_OR_EQUAL_TO, 0.4), + ], +) +OR_FILTER = CompoundFilter( + LogicalOperator.OR, + [ + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 4), + ExperimentIdFilter(ComparisonOperator.EQUAL_TO, 5), + ], +) + + +@pytest.mark.parametrize("left", SINGLE_FILTERS + [OR_FILTER]) +@pytest.mark.parametrize("right", SINGLE_FILTERS + [OR_FILTER]) +def test_non_mergable_and(left, right): + assert (left & right) == CompoundFilter(LogicalOperator.AND, [left, right]) + + +@pytest.mark.parametrize("left", SINGLE_FILTERS + [AND_FILTER]) +@pytest.mark.parametrize("right", SINGLE_FILTERS + [AND_FILTER]) +def test_non_mergable_or(left, right): + assert (left | right) == CompoundFilter(LogicalOperator.OR, [left, right]) + + +@pytest.mark.parametrize("right", SINGLE_FILTERS) +def test_left_mergeable_and(right): + assert (AND_FILTER & right) == CompoundFilter( + LogicalOperator.AND, AND_FILTER.conditions + [right] + ) + + +@pytest.mark.parametrize("right", SINGLE_FILTERS) +def test_left_mergeable_or(right): + assert (OR_FILTER | right) == CompoundFilter( + LogicalOperator.OR, OR_FILTER.conditions + [right] + ) + + +@pytest.mark.parametrize("left", SINGLE_FILTERS) +def test_right_mergeable_and(left): + assert (left & AND_FILTER) == CompoundFilter( + LogicalOperator.AND, [left] + AND_FILTER.conditions + ) + + +@pytest.mark.parametrize("left", SINGLE_FILTERS) +def test_right_mergeable_or(left): + assert (left | OR_FILTER) == CompoundFilter( + LogicalOperator.OR, [left] + OR_FILTER.conditions + ) + + +def test_fully_mergable_and(): + assert (AND_FILTER & AND_FILTER) == CompoundFilter( + LogicalOperator.AND, AND_FILTER.conditions + AND_FILTER.conditions + ) + + +def test_fully_mergable_or(): + assert (OR_FILTER | OR_FILTER) == CompoundFilter( + LogicalOperator.OR, OR_FILTER.conditions + OR_FILTER.conditions + ) diff --git a/tests/test_experiment.py b/tests/test_experiment.py new file mode 100644 index 00000000..63e1f94c --- /dev/null +++ b/tests/test_experiment.py @@ -0,0 +1,353 @@ +# Copyright 2018-2019 Faculty Science Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from datetime import datetime +from uuid import uuid4 +import operator + +import pandas +import pytest +from pytz import UTC +import mock + +from faculty.clients.experiment import ( + ComparisonOperator, + CompoundFilter, + DeletedAtFilter, + ExperimentIdFilter, + ExperimentRun as ClientExperimentRun, + ExperimentRunStatus, + LogicalOperator, + MetricFilter, + ParamFilter, + RunIdFilter, + TagFilter, +) + +from faculty.experiment import ExperimentRun, ExperimentRunList, FilterBy + + +DATETIMES = [ + datetime(2018, 3, 10, 11, 39, 12, 110000, tzinfo=UTC), + datetime(2018, 3, 10, 11, 32, 6, 247000, tzinfo=UTC), + datetime(2018, 3, 10, 11, 32, 30, 172000, tzinfo=UTC), + datetime(2018, 3, 10, 11, 37, 42, 482000, tzinfo=UTC), +] + + +def mock_client_run(): + return ClientExperimentRun( + **{field: mock.Mock() for field in ClientExperimentRun._fields} + ) + + +def expected_resource_run_for_client_run(run): + return ExperimentRun( + id=run.id, + run_number=run.run_number, + experiment_id=run.experiment_id, + name=run.name, + parent_run_id=run.parent_run_id, + artifact_location=run.artifact_location, + status=run.status, + started_at=run.started_at, + ended_at=run.ended_at, + deleted_at=run.deleted_at, + tags=run.tags, + params=run.params, + metrics=run.metrics, + ) + + +def test_experiment_run_query(mocker): + session = mocker.Mock() + get_session_mock = mocker.patch( + "faculty.experiment.get_session", return_value=session + ) + + project_id = mocker.Mock() + resolve_project_id_mock = mocker.patch( + "faculty.experiment.resolve_project_id", return_value=project_id + ) + + client = mocker.Mock() + mocker.patch("faculty.experiment.ExperimentClient", return_value=client) + client_runs = [mock_client_run(), mock_client_run()] + client_response = mocker.Mock() + client_response.runs = client_runs + client_response.pagination.next = None + client.query_runs.return_value = client_response + + filter = mocker.Mock() + sort = mocker.Mock() + + runs = ExperimentRun.query("my project", filter, sort, extra_conf="foo") + assert runs == ExperimentRunList( + [expected_resource_run_for_client_run(run) for run in client_runs] + ) + + get_session_mock.assert_called_once_with(extra_conf="foo") + resolve_project_id_mock.assert_called_once_with(session, "my project") + client.query_runs.assert_called_once_with(project_id, filter, sort) + + +def test_experiment_run_query_multiple_pages(mocker): + session = mocker.Mock() + get_session_mock = mocker.patch( + "faculty.experiment.get_session", return_value=session + ) + + project_id = mocker.Mock() + resolve_project_id_mock = mocker.patch( + "faculty.experiment.resolve_project_id", return_value=project_id + ) + + client = mocker.Mock() + mocker.patch("faculty.experiment.ExperimentClient", return_value=client) + client_response_0 = mocker.Mock() + client_response_0.runs = [mock_client_run(), mock_client_run()] + client_response_1 = mocker.Mock() + client_response_1.runs = [mock_client_run(), mock_client_run()] + client_response_2 = mocker.Mock() + client_response_2.runs = [mock_client_run(), mock_client_run()] + client_response_2.pagination.next = None + client.query_runs.side_effect = [ + client_response_0, + client_response_1, + client_response_2, + ] + all_client_runs = ( + client_response_0.runs + + client_response_1.runs + + client_response_2.runs + ) + + filter = mocker.Mock() + sort = mocker.Mock() + + runs = ExperimentRun.query("my project", filter, sort, extra_conf="foo") + assert runs == ExperimentRunList( + [expected_resource_run_for_client_run(run) for run in all_client_runs] + ) + + get_session_mock.assert_called_once_with(extra_conf="foo") + resolve_project_id_mock.assert_called_once_with(session, "my project") + client.query_runs.assert_has_calls( + [ + mocker.call(project_id, filter, sort), + mocker.call( + project_id, + filter, + sort, + start=client_response_0.pagination.next.start, + limit=client_response_0.pagination.next.limit, + ), + mocker.call( + project_id, + filter, + sort, + start=client_response_1.pagination.next.start, + limit=client_response_1.pagination.next.limit, + ), + ] + ) + + +def test_experiment_run_list_as_dataframe(mocker): + run_0 = mocker.Mock( + experiment_id=1, + id=uuid4(), + run_number=3, + status=ExperimentRunStatus.FINISHED, + started_at=DATETIMES[0], + ended_at=DATETIMES[1], + params=[ + mocker.Mock(key="classic", value="foo"), + mocker.Mock(key="monty", value="spam"), + ], + metrics=[ + mocker.Mock(key="accuracy", value=0.87), + mocker.Mock(key="f1_score", value=0.76), + ], + ) + run_1 = mocker.Mock( + experiment_id=2, + id=uuid4(), + run_number=4, + status=ExperimentRunStatus.RUNNING, + started_at=DATETIMES[2], + ended_at=DATETIMES[3], + params=[ + mocker.Mock(key="classic", value="bar"), + mocker.Mock(key="monty", value="eggs"), + ], + metrics=[ + mocker.Mock(key="accuracy", value=0.91), + mocker.Mock(key="f1_score", value=0.72), + ], + ) + + runs_df = ExperimentRunList([run_0, run_1]).as_dataframe() + + assert list(runs_df.columns) == [ + ("experiment_id", ""), + ("run_id", ""), + ("run_number", ""), + ("status", ""), + ("started_at", ""), + ("ended_at", ""), + ("params", "classic"), + ("params", "monty"), + ("metrics", "accuracy"), + ("metrics", "f1_score"), + ] + assert (runs_df.experiment_id == [1, 2]).all() + assert (runs_df.run_id == [run_0.id, run_1.id]).all() + assert (runs_df.run_number == [3, 4]).all() + assert (runs_df.status == ["finished", "running"]).all() + assert ( + runs_df.started_at == pandas.Series([DATETIMES[0], DATETIMES[2]]) + ).all() + assert ( + runs_df.ended_at == pandas.Series([DATETIMES[1], DATETIMES[3]]) + ).all() + assert (runs_df.params.classic == ["foo", "bar"]).all() + assert (runs_df.params.monty == ["spam", "eggs"]).all() + assert (runs_df.metrics.accuracy == [0.87, 0.91]).all() + assert (runs_df.metrics.f1_score == [0.76, 0.72]).all() + + +def test_experiment_run_list_as_dataframe_no_params(mocker): + run = mocker.Mock( + experiment_id=1, + id=uuid4(), + run_number=3, + status=ExperimentRunStatus.FINISHED, + started_at=DATETIMES[0], + ended_at=DATETIMES[1], + params=[], + metrics=[mocker.Mock(key="accuracy", value=0.91)], + ) + + runs_df = ExperimentRunList([run]).as_dataframe() + + assert list(runs_df.columns) == [ + ("experiment_id", ""), + ("run_id", ""), + ("run_number", ""), + ("status", ""), + ("started_at", ""), + ("ended_at", ""), + ("metrics", "accuracy"), + ] + + +def test_experiment_run_list_as_dataframe_no_metrics(mocker): + run = mocker.Mock( + experiment_id=1, + id=uuid4(), + run_number=3, + status=ExperimentRunStatus.FINISHED, + started_at=DATETIMES[0], + ended_at=DATETIMES[1], + params=[mocker.Mock(key="classic", value="bar")], + metrics=[], + ) + + runs_df = ExperimentRunList([run]).as_dataframe() + + assert list(runs_df.columns) == [ + ("experiment_id", ""), + ("run_id", ""), + ("run_number", ""), + ("status", ""), + ("started_at", ""), + ("ended_at", ""), + ("params", "classic"), + ] + + +FILTER_BY_NO_KEY_CASES = [ + (FilterBy.experiment_id, ExperimentIdFilter), + (FilterBy.run_id, RunIdFilter), + (FilterBy.deleted_at, DeletedAtFilter), +] +FILTER_BY_WITH_KEY_CASES = [ + (FilterBy.tag, TagFilter), + (FilterBy.param, ParamFilter), + (FilterBy.metric, MetricFilter), +] + +OPERATOR_CASES = [ + (lambda x, v: x.defined(v), ComparisonOperator.DEFINED), + (operator.eq, ComparisonOperator.EQUAL_TO), + (operator.ne, ComparisonOperator.NOT_EQUAL_TO), + (operator.gt, ComparisonOperator.GREATER_THAN), + (operator.ge, ComparisonOperator.GREATER_THAN_OR_EQUAL_TO), + (operator.lt, ComparisonOperator.LESS_THAN), + (operator.le, ComparisonOperator.LESS_THAN_OR_EQUAL_TO), +] + + +@pytest.mark.parametrize("method, filter_cls", FILTER_BY_NO_KEY_CASES) +@pytest.mark.parametrize("python_operator, expected_operator", OPERATOR_CASES) +def test_filter_by_no_key( + mocker, method, filter_cls, python_operator, expected_operator +): + value = mocker.Mock() + filter = python_operator(method(), value) + expected = filter_cls(expected_operator, value) + assert filter == expected + + +@pytest.mark.parametrize("method, filter_cls", FILTER_BY_WITH_KEY_CASES) +@pytest.mark.parametrize("python_operator, expected_operator", OPERATOR_CASES) +def test_filter_by_with_key( + mocker, method, filter_cls, python_operator, expected_operator +): + key = mocker.Mock() + value = mocker.Mock() + filter = python_operator(method(key), value) + expected = filter_cls(key, expected_operator, value) + assert filter == expected + + +@pytest.mark.parametrize("method, filter_cls", FILTER_BY_NO_KEY_CASES) +def test_filter_by_one_of_no_key(mocker, method, filter_cls): + values = [mocker.Mock(), mocker.Mock()] + filter = method().one_of(values) + expected = CompoundFilter( + LogicalOperator.OR, + [ + filter_cls(ComparisonOperator.EQUAL_TO, values[0]), + filter_cls(ComparisonOperator.EQUAL_TO, values[1]), + ], + ) + assert filter == expected + + +@pytest.mark.parametrize("method, filter_cls", FILTER_BY_WITH_KEY_CASES) +def test_filter_by_one_of_with_key(mocker, method, filter_cls): + key = mocker.Mock() + values = [mocker.Mock(), mocker.Mock()] + filter = method(key).one_of(values) + expected = CompoundFilter( + LogicalOperator.OR, + [ + filter_cls(key, ComparisonOperator.EQUAL_TO, values[0]), + filter_cls(key, ComparisonOperator.EQUAL_TO, values[1]), + ], + ) + assert filter == expected diff --git a/tox.ini b/tox.ini index 0b5826ae..9624a1e8 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py27, py34, py35, py36, py37, flake8, black, license +envlist = py27, py35, py36, py37, flake8, black, license [testenv] sitepackages = False @@ -9,6 +9,7 @@ deps = pytest-mock requests_mock python-dateutil>=2.7 + mock commands = pytest {posargs} [testenv:flake8]