Skip to content

Commit

Permalink
fixed bug with passing down authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders committed Oct 19, 2023
1 parent 234326d commit 3c371f9
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 14 deletions.
11 changes: 10 additions & 1 deletion skyplane/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,13 @@ def copy(self, src: str, dst: str, recursive: bool = False, max_instances: Optio
pipeline.start(progress=True)

def object_store(self):
return ObjectStore()
"""
Returns an object store interface
"""
return ObjectStore(
host_uuid=self.clientid,
aws_auth=self.aws_auth,
azure_auth=self.azure_auth,
gcp_auth=self.gcp_auth,
ibmcloud_auth=self.ibmcloud_auth
)
53 changes: 45 additions & 8 deletions skyplane/api/obj_store.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,65 @@
from skyplane import compute
from typing import Optional
from skyplane.obj_store.object_store_interface import ObjectStoreInterface


class ObjectStore:
def __init__(self) -> None:
pass

def __init__(
self,
aws_auth: Optional[compute.AWSAuthentication] = None,
azure_auth: Optional[compute.AzureAuthentication] = None,
gcp_auth: Optional[compute.GCPAuthentication] = None,
host_uuid: Optional[str] = None,
ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None,
):
"""
:param aws_auth: authentication information for aws
:type aws_auth: compute.AWSAuthentication
:param azure_auth: authentication information for azure
:type azure_auth: compute.AzureAuthentication
:param gcp_auth: authentication information for gcp
:type gcp_auth: compute.GCPAuthentication
:param host_uuid: the uuid of the local host that requests the provision task
:type host_uuid: string
:param ibmcloud_auth: authentication information for aws
:type ibmcloud_auth: compute.IBMCloudAuthentication
"""
self.aws_auth = aws_auth
self.azure_auth = azure_auth
self.gcp_auth = gcp_auth
self.host_uuid = host_uuid
self.ibmcloud_auth = ibmcloud_auth

def create_interface(self, provider, bucket_name):
return ObjectStoreInterface.create(
f"{provider}:infer",
bucket_name,
self.aws_auth,
self.azure_auth,
self.gcp_auth,
self.host_uuid,
self.ibmcloud_auth
)

def download_object(self, bucket_name: str, provider: str, key: str, filename: str):
obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name)
obj_store = self.create_interface(provider, bucket_name)
obj_store.download_object(key, filename)

def upload_object(self, filename: str, bucket_name: str, provider: str, key: str):
obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name)
obj_store = self.create_interface(provider, bucket_name)
obj_store.upload_object(filename, key)

def exists(self, bucket_name: str, provider: str, key: str) -> bool:
obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name)
obj_store = self.create_interface(provider, bucket_name)
return obj_store.exists(key)

def bucket_exists(self, bucket_name: str, provider: str) -> bool:
# azure not implemented
if provider == "azure":
raise NotImplementedError(f"Provider {provider} not implemented")

obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name)
obj_store = self.create_interface(provider, bucket_name)
return obj_store.bucket_exists()

def create_bucket(self, region: str, bucket_name: str):
Expand All @@ -31,7 +68,7 @@ def create_bucket(self, region: str, bucket_name: str):
if provider == "azure":
raise NotImplementedError(f"Provider {provider} not implemented")

obj_store = ObjectStoreInterface.create(region, bucket_name)
obj_store = self.create_interface(provider, bucket_name)
obj_store.create_bucket(region.split(":")[1])

# TODO: create util function for this
Expand All @@ -47,5 +84,5 @@ def delete_bucket(self, bucket_name: str, provider: str):
if provider == "azure":
raise NotImplementedError(f"Provider {provider} not implemented")

obj_store = ObjectStoreInterface.create(f"{provider}:infer", bucket_name)
obj_store = self.create_interface(provider, bucket_name)
obj_store.delete_bucket()
5 changes: 3 additions & 2 deletions skyplane/obj_store/s3_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ def full_path(self):


class S3Interface(ObjectStoreInterface):
def __init__(self, bucket_name: str):
self.auth = compute.AWSAuthentication()
def __init__(self, bucket_name: str, auth: Optional[compute.AWSAuthentication] = None):
print("PASSED AUTH", auth)
self.auth = compute.AWSAuthentication() if auth is None else auth
self.requester_pays = False
self.bucket_name = bucket_name
self._cached_s3_clients = {}
Expand Down
16 changes: 13 additions & 3 deletions skyplane/obj_store/storage_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from skyplane import compute
from skyplane.utils import logger
from typing import Iterator, Any
from typing import Iterator, Any, Optional


class StorageInterface:
Expand Down Expand Up @@ -35,12 +36,21 @@ def list_objects(self, prefix="") -> Iterator[Any]:
raise NotImplementedError()

@staticmethod
def create(region_tag: str, bucket: str):
def create(
region_tag: str,
bucket: str,
aws_auth: Optional[compute.AWSAuthentication] = None,
azure_auth: Optional[compute.AzureAuthentication] = None,
gcp_auth: Optional[compute.GCPAuthentication] = None,
host_uuid: Optional[str] = None,
ibmcloud_auth: Optional[compute.IBMCloudAuthentication] = None
):
# TODO: plug in manual setting for all other authentications
# TODO: modify this to also support local file
if region_tag.startswith("aws"):
from skyplane.obj_store.s3_interface import S3Interface

return S3Interface(bucket)
return S3Interface(bucket, aws_auth)
elif region_tag.startswith("gcp"):
from skyplane.obj_store.gcs_interface import GCSInterface

Expand Down

0 comments on commit 3c371f9

Please sign in to comment.