diff --git a/harness/determined/common/api/__init__.py b/harness/determined/common/api/__init__.py index b70004ba126..90f8abbd147 100644 --- a/harness/determined/common/api/__init__.py +++ b/harness/determined/common/api/__init__.py @@ -1,5 +1,5 @@ from determined.common.api import authentication, errors, metric, bindings -from determined.common.api._session import BaseSession, UnauthSession, Session +from determined.common.api._session import BaseSession, UnauthSession, Session, TaskSession from determined.common.api._util import ( PageOpts, get_ntsc_details, diff --git a/harness/determined/common/api/_session.py b/harness/determined/common/api/_session.py index e7c5f306e6e..5baeccbe365 100644 --- a/harness/determined/common/api/_session.py +++ b/harness/determined/common/api/_session.py @@ -281,6 +281,8 @@ class Session(BaseSession): By far, most BaseSessions in the codebase will be this Session subclass. """ + AUTH_HEADER = "Authorization" + def __init__( self, master: str, @@ -308,10 +310,27 @@ def _make_http_session(self) -> requests.Session: server_hostname=self.cert.name if self.cert else None, verify=self.cert.bundle if self.cert else None, max_retries=self._max_retries, - headers={"Authorization": f"Bearer {self.token}"}, + headers={self.AUTH_HEADER: f"Bearer {self.token}"}, ) +class TaskSession(Session): + """ + ``TaskSession`` is a subclass of ``Session`` designed to be used for authenticating requests + using a task session token. It simply overrides the authentication header name used for + requests. + + Most sessions that are created from user input should use ``Session`` instead, which + authenticates requests using a user token (i.e. the CLI, SDK). + + Task session tokens really only have a longer expiration, and should be used for internal + sessions that may persist throughout a long training job (i.e. Core API). + """ + + # Authentication header name for task session tokens + AUTH_HEADER = "Grpc-Metadata-x-allocation-token" + + class _HTTPSAdapter(adapters.HTTPAdapter): """Overrides the hostname checked against for TLS verification. diff --git a/harness/determined/common/api/authentication.py b/harness/determined/common/api/authentication.py index 7eea446c6b0..50109125b80 100644 --- a/harness/determined/common/api/authentication.py +++ b/harness/determined/common/api/authentication.py @@ -101,6 +101,10 @@ def get_det_password_from_env() -> Optional[str]: return os.environ.get("DET_PASS") +def get_det_session_token_from_env() -> Optional[str]: + return os.environ.get("DET_SESSION_TOKEN") + + def login( master_address: str, username: str, @@ -307,6 +311,27 @@ def logout_all(master_address: str, cert: Optional[certs.Cert]) -> None: logout(master_address, user, cert) +def login_from_task( + master_address: str, + cert: Optional[certs.Cert], +) -> "api.TaskSession": + """ + Creates a ``TaskSession`` from environment variables to be used for authenticating subsequent + requests. + + This method should only be called on-cluster, from inside a task container. + """ + session_token = get_det_session_token_from_env() + if not session_token: + raise ValueError("DET_SESSION_TOKEN environment variable not set.") + + username = get_det_username_from_env() + if not username: + raise ValueError("DET_USER environment variable not set.") + + return api.TaskSession(master=master_address, username=username, token=session_token, cert=cert) + + def _is_token_valid(master_address: str, token: str, cert: Optional[certs.Cert]) -> bool: """ Find out whether the given token is valid by attempting to use it diff --git a/harness/determined/core/_context.py b/harness/determined/core/_context.py index 827a355b952..1e5f200f6db 100644 --- a/harness/determined/core/_context.py +++ b/harness/determined/core/_context.py @@ -242,9 +242,10 @@ def init( # We are on the cluster. cert = certs.default_load(info.master_url) - session = authentication.login_with_cache(info.master_url, cert=cert).with_retry( - util.get_max_retries_config() - ) + session = authentication.login_from_task( + master_address=info.master_url, + cert=cert, + ).with_retry(util.get_max_retries_config()) if distributed is None: if len(info.container_addrs) > 1 or len(info.slot_ids) > 1: diff --git a/harness/determined/exec/gc_checkpoints.py b/harness/determined/exec/gc_checkpoints.py index 8baa28e0370..0693962f39f 100644 --- a/harness/determined/exec/gc_checkpoints.py +++ b/harness/determined/exec/gc_checkpoints.py @@ -27,7 +27,7 @@ def patch_checkpoints(storage_ids_to_resources: Dict[str, Dict[str, int]]) -> No cert = certs.default_load(info.master_url) # With backoff retries for 64 seconds - sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry( + sess = authentication.login_from_task(info.master_url, cert=cert).with_retry( urllib3.util.retry.Retry(total=6, backoff_factor=0.5) ) diff --git a/harness/determined/exec/launch.py b/harness/determined/exec/launch.py index 9ef077a119d..96502ed2c78 100644 --- a/harness/determined/exec/launch.py +++ b/harness/determined/exec/launch.py @@ -22,7 +22,7 @@ def trigger_preemption(signum: int, frame: types.FrameType) -> None: logger.info("SIGTERM: Preemption imminent.") # Notify the master that we need to be preempted cert = certs.default_load(info.master_url) - sess = authentication.login_with_cache(info.master_url, cert=cert) + sess = authentication.login_from_task(info.master_url, cert=cert) sess.post(f"/api/v1/allocations/{info.allocation_id}/signals/pending_preemption") diff --git a/harness/determined/exec/prep_container.py b/harness/determined/exec/prep_container.py index 18e27e76584..392bcdda1e7 100644 --- a/harness/determined/exec/prep_container.py +++ b/harness/determined/exec/prep_container.py @@ -388,7 +388,7 @@ def do_proxy(sess: api.Session, allocation_id: str) -> None: cert = certs.default_load(info.master_url) # With backoff retries for 64 seconds - sess = authentication.login_with_cache(info.master_url, cert=cert).with_retry( + sess = authentication.login_from_task(info.master_url, cert=cert).with_retry( urllib3.util.retry.Retry(total=6, backoff_factor=0.5) ) diff --git a/harness/determined/experimental/core_v2/_core_context_v2.py b/harness/determined/experimental/core_v2/_core_context_v2.py index f3cb3623813..3a0c4c624a2 100644 --- a/harness/determined/experimental/core_v2/_core_context_v2.py +++ b/harness/determined/experimental/core_v2/_core_context_v2.py @@ -6,7 +6,7 @@ import determined as det from determined import core, experimental, tensorboard -from determined.common import constants, storage, util +from determined.common import api, constants, storage, util from determined.common.api import authentication, certs logger = logging.getLogger("determined.core") @@ -43,9 +43,9 @@ def _make_v2_context( # We are on the cluster. cert = certs.default_load(info.master_url) - session = authentication.login_with_cache(info.master_url, cert=cert).with_retry( - util.get_max_retries_config() - ) + session: api.Session = authentication.login_from_task( + info.master_url, cert=cert + ).with_retry(util.get_max_retries_config()) else: unmanaged = True diff --git a/harness/determined/launch/deepspeed.py b/harness/determined/launch/deepspeed.py index 347306a1ff1..6cff0c91d91 100644 --- a/harness/determined/launch/deepspeed.py +++ b/harness/determined/launch/deepspeed.py @@ -286,7 +286,7 @@ def main(script: List[str]) -> int: # Mark sshd containers as daemon containers that the master should kill when all non-daemon # containers (deepspeed launcher, in this case) have exited. cert = certs.default_load(info.master_url) - sess = authentication.login_with_cache(info.master_url, cert=cert) + sess = authentication.login_from_task(info.master_url, cert=cert) sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon") # Wrap it in a pid_server to ensure that we can't hang if a worker fails. diff --git a/harness/determined/launch/horovod.py b/harness/determined/launch/horovod.py index a30f5bce83a..899a290c364 100644 --- a/harness/determined/launch/horovod.py +++ b/harness/determined/launch/horovod.py @@ -128,7 +128,7 @@ def main(hvd_args: List[str], script: List[str], autohorovod: bool) -> int: # Mark sshd containers as daemon resources that the master should kill when all non-daemon # containers (horovodrun, in this case) have exited. cert = certs.default_load(info.master_url) - sess = authentication.login_with_cache(info.master_url, cert=cert) + sess = authentication.login_from_task(info.master_url, cert=cert) sess.post(f"/api/v1/allocations/{info.allocation_id}/resources/{resources_id}/daemon") pid_server_cmd, run_sshd_command = create_sshd_worker_cmd( diff --git a/harness/tests/cli/test_auth.py b/harness/tests/cli/test_auth.py index 7a7cd6f66e0..7209a7716db 100644 --- a/harness/tests/cli/test_auth.py +++ b/harness/tests/cli/test_auth.py @@ -10,7 +10,7 @@ from determined.cli import cli from determined.common import api -from determined.common.api import authentication +from determined.common.api import authentication, certs from tests.cli import util MOCK_MASTER_URL = "http://localhost:8080" @@ -439,3 +439,28 @@ def test_logout_all() -> None: mts.clear_active() cli.main(["user", "logout", "--all"]) + + +def test_login_from_task() -> None: + mock_session_token = "abcde12345" + mock_user = "abababa" + mock_cert = certs.Cert() + with contextlib.ExitStack() as es: + # Configure environment variables. + es.enter_context(util.setenv_optional("DET_SESSION_TOKEN", mock_session_token)) + es.enter_context(util.setenv_optional("DET_USER", mock_user)) + + with responses.RequestsMock( + registry=registries.OrderedRegistry, assert_all_requests_are_fired=True + ) as rsps: + sess = authentication.login_from_task(master_address=MOCK_MASTER_URL, cert=mock_cert) + assert sess.token == mock_session_token + assert sess.username == mock_user + assert sess.cert == mock_cert + + rsps.get( + f"{MOCK_MASTER_URL}/api/v1/me", + status=200, + match=[matchers.header_matcher({sess.AUTH_HEADER: f"Bearer {mock_session_token}"})], + ) + sess.get("/api/v1/me") diff --git a/master/static/srv/check_idle.py b/master/static/srv/check_idle.py index 9949065b1ae..f80c4986655 100755 --- a/master/static/srv/check_idle.py +++ b/master/static/srv/check_idle.py @@ -83,7 +83,7 @@ def main(): notebook_server = f"https://127.0.0.1:{port}/proxy/{notebook_id}" master_url = api.canonicalize_master_url(os.environ["DET_MASTER"]) cert = certs.default_load(master_url) - sess = authentication.login_with_cache(master_url, cert=cert) + sess = authentication.login_from_task(master_url, cert=cert) try: idle_type = IdleType[os.environ["NOTEBOOK_IDLE_TYPE"].upper()] except KeyError: diff --git a/master/static/srv/check_ready_logs.py b/master/static/srv/check_ready_logs.py index ee409b10b19..51e1a4ba186 100644 --- a/master/static/srv/check_ready_logs.py +++ b/master/static/srv/check_ready_logs.py @@ -37,7 +37,7 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None): cert = certs.default_load(master_url) # This only runs on-cluster, so it is expected the username and session token are present in the # environment. - sess = authentication.login_with_cache(master_url, cert=cert) + sess = authentication.login_from_task(master_url, cert=cert) allocation_id = str(os.environ["DET_ALLOCATION_ID"]) for line in sys.stdin: if ready.match(line): @@ -49,11 +49,15 @@ def main(ready: Pattern, waiting: Optional[Pattern] = None): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Read STDIN for a match and mark a task as ready") + parser = argparse.ArgumentParser( + description="Read STDIN for a match and mark a task as ready" + ) parser.add_argument( "--ready-regex", type=str, help="the pattern to match task ready", required=True ) - parser.add_argument("--waiting-regex", type=str, help="the pattern to match task waiting") + parser.add_argument( + "--waiting-regex", type=str, help="the pattern to match task waiting" + ) args = parser.parse_args() ready_regex = re.compile(args.ready_regex)