Skip to content

Commit

Permalink
fix: use task sessions in Core API
Browse files Browse the repository at this point in the history
  • Loading branch information
azhou-determined committed Aug 22, 2024
1 parent 3a91552 commit 716cd95
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 19 deletions.
2 changes: 1 addition & 1 deletion harness/determined/common/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from determined.common.api import authentication, errors, metric, bindings
from determined.common.api._session import BaseSession, UnauthSession, Session
from determined.common.api._session import BaseSession, UnauthSession, Session, TaskSession
from determined.common.api._util import (
PageOpts,
get_ntsc_details,
Expand Down
21 changes: 20 additions & 1 deletion harness/determined/common/api/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ class Session(BaseSession):
By far, most BaseSessions in the codebase will be this Session subclass.
"""

AUTH_HEADER = "Authorization"

def __init__(
self,
master: str,
Expand Down Expand Up @@ -308,10 +310,27 @@ def _make_http_session(self) -> requests.Session:
server_hostname=self.cert.name if self.cert else None,
verify=self.cert.bundle if self.cert else None,
max_retries=self._max_retries,
headers={"Authorization": f"Bearer {self.token}"},
headers={self.AUTH_HEADER: f"Bearer {self.token}"},
)


class TaskSession(Session):
"""
``TaskSession`` is a subclass of ``Session`` designed to be used for authenticating requests
using a task session token. It simply overrides the authentication header name used for
requests.
Most sessions that are created from user input should use ``Session`` instead, which
authenticates requests using a user token (i.e. the CLI, SDK).
Task session tokens really only have a longer expiration, and should be used for internal
sessions that may persist throughout a long training job (i.e. Core API).
"""

# Authentication header name for task session tokens
AUTH_HEADER = "Grpc-Metadata-x-allocation-token"


class _HTTPSAdapter(adapters.HTTPAdapter):
"""Overrides the hostname checked against for TLS verification.
Expand Down
25 changes: 25 additions & 0 deletions harness/determined/common/api/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def get_det_password_from_env() -> Optional[str]:
return os.environ.get("DET_PASS")


def get_det_session_token_from_env() -> Optional[str]:
return os.environ.get("DET_SESSION_TOKEN")


def login(
master_address: str,
username: str,
Expand Down Expand Up @@ -307,6 +311,27 @@ def logout_all(master_address: str, cert: Optional[certs.Cert]) -> None:
logout(master_address, user, cert)


def login_from_task(
master_address: str,
cert: Optional[certs.Cert],
) -> "api.TaskSession":
"""
Creates a ``TaskSession`` from environment variables to be used for authenticating subsequent
requests.
This method should only be called on-cluster, from inside a task container.
"""
session_token = get_det_session_token_from_env()
if not session_token:
raise ValueError("DET_SESSION_TOKEN environment variable not set.")

username = get_det_username_from_env()
if not username:
raise ValueError("DET_USER environment variable not set.")

return api.TaskSession(master=master_address, username=username, token=session_token, cert=cert)


def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool:
"""
Find out whether the given token is valid by attempting to use it
Expand Down
7 changes: 4 additions & 3 deletions harness/determined/core/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,10 @@ def init(

# We are on the cluster.
cert = certs.default_load(info.master_url)
session = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
util.get_max_retries_config()
)
session = authentication.login_from_task(
master_address=info.master_url,
cert=cert,
).with_retry(util.get_max_retries_config())

if distributed is None:
if len(info.container_addrs) > 1 or len(info.slot_ids) > 1:
Expand Down
2 changes: 1 addition & 1 deletion harness/determined/exec/gc_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def patch_checkpoints(storage_ids_to_resources: Dict[str, Dict[str, int]]) -> No

cert = certs.default_load(info.master_url)
# With backoff retries for 64 seconds
sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
sess = authentication.login_from_task(info.master_url, cert=cert).with_retry(
urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
)

Expand Down
2 changes: 1 addition & 1 deletion harness/determined/exec/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def trigger_preemption(signum: int, frame: types.FrameType) -> None:
logger.info("SIGTERM: Preemption imminent.")
# Notify the master that we need to be preempted
cert = certs.default_load(info.master_url)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess = authentication.login_from_task(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption")


Expand Down
2 changes: 1 addition & 1 deletion harness/determined/exec/prep_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def do_proxy(sess: api.Session, allocation_id: str) -> None:

cert = certs.default_load(info.master_url)
# With backoff retries for 64 seconds
sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
sess = authentication.login_from_task(info.master_url, cert=cert).with_retry(
urllib3.util.retry.Retry(total=6, backoff_factor=0.5)
)

Expand Down
8 changes: 4 additions & 4 deletions harness/determined/experimental/core_v2/_core_context_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import determined as det
from determined import core, experimental, tensorboard
from determined.common import constants, storage, util
from determined.common import api, constants, storage, util
from determined.common.api import authentication, certs

logger = logging.getLogger("determined.core")
Expand Down Expand Up @@ -43,9 +43,9 @@ def _make_v2_context(

# We are on the cluster.
cert = certs.default_load(info.master_url)
session = authentication.login_with_cache(info.master_url, cert=cert).with_retry(
util.get_max_retries_config()
)
session: api.Session = authentication.login_from_task(
info.master_url, cert=cert
).with_retry(util.get_max_retries_config())
else:
unmanaged = True

Expand Down
2 changes: 1 addition & 1 deletion harness/determined/launch/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def main(script: List[str]) -> int:
# Mark sshd containers as daemon containers that the master should kill when all non-daemon
# containers (deepspeed launcher, in this case) have exited.
cert = certs.default_load(info.master_url)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess = authentication.login_from_task(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon")

# Wrap it in a pid_server to ensure that we can't hang if a worker fails.
Expand Down
2 changes: 1 addition & 1 deletion harness/determined/launch/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(hvd_args: List[str], script: List[str], autohorovod: bool) -> int:
# Mark sshd containers as daemon resources that the master should kill when all non-daemon
# containers (horovodrun, in this case) have exited.
cert = certs.default_load(info.master_url)
sess = authentication.login_with_cache(info.master_url, cert=cert)
sess = authentication.login_from_task(info.master_url, cert=cert)
sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon")

pid_server_cmd, run_sshd_command = create_sshd_worker_cmd(
Expand Down
27 changes: 26 additions & 1 deletion harness/tests/cli/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from determined.cli import cli
from determined.common import api
from determined.common.api import authentication
from determined.common.api import authentication, certs
from tests.cli import util

MOCK_MASTER_URL = "http://localhost:8080"
Expand Down Expand Up @@ -439,3 +439,28 @@ def test_logout_all() -> None:
mts.clear_active()

cli.main(["user", "logout", "--all"])


def test_login_from_task() -> None:
mock_session_token = "abcde12345"
mock_user = "abababa"
mock_cert = certs.Cert()
with contextlib.ExitStack() as es:
# Configure environment variables.
es.enter_context(util.setenv_optional("DET_SESSION_TOKEN", mock_session_token))
es.enter_context(util.setenv_optional("DET_USER", mock_user))

with responses.RequestsMock(
registry=registries.OrderedRegistry, assert_all_requests_are_fired=True
) as rsps:
sess = authentication.login_from_task(master_address=MOCK_MASTER_URL, cert=mock_cert)
assert sess.token == mock_session_token
assert sess.username == mock_user
assert sess.cert == mock_cert

rsps.get(
f"{MOCK_MASTER_URL}/api/v1/me",
status=200,
match=[matchers.header_matcher({sess.AUTH_HEADER: f"Bearer {mock_session_token}"})],
)
sess.get("/api/v1/me")
2 changes: 1 addition & 1 deletion master/static/srv/check_idle.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def main():
notebook_server = f"https://127.0.0.1:{port}/proxy/{notebook_id}"
master_url = api.canonicalize_master_url(os.environ["DET_MASTER"])
cert = certs.default_load(master_url)
sess = authentication.login_with_cache(master_url, cert=cert)
sess = authentication.login_from_task(master_url, cert=cert)
try:
idle_type = IdleType[os.environ["NOTEBOOK_IDLE_TYPE"].upper()]
except KeyError:
Expand Down
10 changes: 7 additions & 3 deletions master/static/srv/check_ready_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None):
cert = certs.default_load(master_url)
# This only runs on-cluster, so it is expected the username and session token are present in the
# environment.
sess = authentication.login_with_cache(master_url, cert=cert)
sess = authentication.login_from_task(master_url, cert=cert)
allocation_id = str(os.environ["DET_ALLOCATION_ID"])
for line in sys.stdin:
if ready.match(line):
Expand All @@ -49,11 +49,15 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Read STDIN for a match and mark a task as ready")
parser = argparse.ArgumentParser(
description="Read STDIN for a match and mark a task as ready"
)
parser.add_argument(
"--ready-regex", type=str, help="the pattern to match task ready", required=True
)
parser.add_argument("--waiting-regex", type=str, help="the pattern to match task waiting")
parser.add_argument(
"--waiting-regex", type=str, help="the pattern to match task waiting"
)
args = parser.parse_args()

ready_regex = re.compile(args.ready_regex)
Expand Down

0 comments on commit 716cd95

Please sign in to comment.