From 3c371f9fa7bfca4d2be614df8818e57b746a9845 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 18 Oct 2023 21:11:16 -0700 Subject: [PATCH] fixed bug with passing down authentication --- skyplane/api/client.py | 11 ++++- skyplane/api/obj_store.py | 53 +++++++++++++++++++++---- skyplane/obj_store/s3_interface.py | 5 ++- skyplane/obj_store/storage_interface.py | 16 ++++++-- 4 files changed, 71 insertions(+), 14 deletions(-) diff --git a/skyplane/api/client.py b/skyplane/api/client.py index 526337b24..a1a43143a 100644 --- a/skyplane/api/client.py +++ b/skyplane/api/client.py @@ -99,4 +99,13 @@ def copy(self, src: str, dst: str, recursive: bool = False, max_instances: Optio pipeline.start(progress=True) def object_store(self): - return ObjectStore() + """ + Returns an object store interface + """ + return ObjectStore( + host_uuid=self.clientid, + aws_auth=self.aws_auth, + azure_auth=self.azure_auth, + gcp_auth=self.gcp_auth, + ibmcloud_auth=self.ibmcloud_auth + ) diff --git a/skyplane/api/obj_store.py b/skyplane/api/obj_store.py index cbf57a05b..42808e285 100644 --- a/skyplane/api/obj_store.py +++ b/skyplane/api/obj_store.py @@ -1,20 +1,57 @@ +from skyplane import compute +from typing import Optional from skyplane.obj_store.object_store_interface import ObjectStoreInterface class ObjectStore: - def __init__(self) -> None: - pass + + def __init__( + self, + aws_auth: Optional[compute.AWSAuthentication] = None, + azure_auth: Optional[compute.AzureAuthentication] = None, + gcp_auth: Optional[compute.GCPAuthentication] = None, + host_uuid: Optional[str] = None, + ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None, + ): + """ + :param aws_auth: authentication information for aws + :type aws_auth: compute.AWSAuthentication + :param azure_auth: authentication information for azure + :type azure_auth: compute.AzureAuthentication + :param gcp_auth: authentication information for gcp + :type gcp_auth: compute.GCPAuthentication + :param host_uuid: the uuid of the local host that requests the provision task + :type host_uuid: string + :param ibmcloud_auth: authentication information for aws + :type ibmcloud_auth: compute.IBMCloudAuthentication + """ + self.aws_auth = aws_auth + self.azure_auth = azure_auth + self.gcp_auth = gcp_auth + self.host_uuid = host_uuid + self.ibmcloud_auth = ibmcloud_auth + + def create_interface(self, provider, bucket_name): + return ObjectStoreInterface.create( + f"{provider}:infer", + bucket_name, + self.aws_auth, + self.azure_auth, + self.gcp_auth, + self.host_uuid, + self.ibmcloud_auth + ) def download_object(self, bucket_name: str, provider: str, key: str, filename: str): - obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name) + obj_store = self.create_interface(provider, bucket_name) obj_store.download_object(key, filename) def upload_object(self, filename: str, bucket_name: str, provider: str, key: str): - obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name) + obj_store = self.create_interface(provider, bucket_name) obj_store.upload_object(filename, key) def exists(self, bucket_name: str, provider: str, key: str) -> bool: - obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name) + obj_store = self.create_interface(provider, bucket_name) return obj_store.exists(key) def bucket_exists(self, bucket_name: str, provider: str) -> bool: @@ -22,7 +59,7 @@ def bucket_exists(self, bucket_name: str, provider: str) -> bool: if provider == "azure": raise NotImplementedError(f"Provider {provider} not implemented") - obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name) + obj_store = self.create_interface(provider, bucket_name) return obj_store.bucket_exists() def create_bucket(self, region: str, bucket_name: str): @@ -31,7 +68,7 @@ def create_bucket(self, region: str, bucket_name: str): if provider == "azure": raise NotImplementedError(f"Provider {provider} not implemented") - obj_store = ObjectStoreInterface.create(region, bucket_name) + obj_store = self.create_interface(provider, bucket_name) obj_store.create_bucket(region.split(":")[1]) # TODO: create util function for this @@ -47,5 +84,5 @@ def delete_bucket(self, bucket_name: str, provider: str): if provider == "azure": raise NotImplementedError(f"Provider {provider} not implemented") - obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name) + obj_store = self.create_interface(provider, bucket_name) obj_store.delete_bucket() diff --git a/skyplane/obj_store/s3_interface.py b/skyplane/obj_store/s3_interface.py index f502cd77e..59b78cf29 100644 --- a/skyplane/obj_store/s3_interface.py +++ b/skyplane/obj_store/s3_interface.py @@ -19,8 +19,9 @@ def full_path(self): class S3Interface(ObjectStoreInterface): - def __init__(self, bucket_name: str): - self.auth = compute.AWSAuthentication() + def __init__(self, bucket_name: str, auth: Optional[compute.AWSAuthentication] = None): + print("PASSED AUTH", auth) + self.auth = compute.AWSAuthentication() if auth is None else auth self.requester_pays = False self.bucket_name = bucket_name self._cached_s3_clients = {} diff --git a/skyplane/obj_store/storage_interface.py b/skyplane/obj_store/storage_interface.py index a430573f8..6abfc247c 100644 --- a/skyplane/obj_store/storage_interface.py +++ b/skyplane/obj_store/storage_interface.py @@ -1,5 +1,6 @@ +from skyplane import compute from skyplane.utils import logger -from typing import Iterator, Any +from typing import Iterator, Any, Optional class StorageInterface: @@ -35,12 +36,21 @@ def list_objects(self, prefix="") -> Iterator[Any]: raise NotImplementedError() @staticmethod - def create(region_tag: str, bucket: str): + def create( + region_tag: str, + bucket: str, + aws_auth: Optional[compute.AWSAuthentication] = None, + azure_auth: Optional[compute.AzureAuthentication] = None, + gcp_auth: Optional[compute.GCPAuthentication] = None, + host_uuid: Optional[str] = None, + ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None + ): + # TODO: plug in manual setting for all other authentications # TODO: modify this to also support local file if region_tag.startswith("aws"): from skyplane.obj_store.s3_interface import S3Interface - return S3Interface(bucket) + return S3Interface(bucket, aws_auth) elif region_tag.startswith("gcp"): from skyplane.obj_store.gcs_interface import GCSInterface