From 5a7b3e034d1d2fb409e8efdc660c1c06b8e4b65b Mon Sep 17 00:00:00 2001
From: raphaelreinauer
Date: Fri, 30 Sep 2022 23:15:09 +0200
Subject: [PATCH 1/7] get rid of "No TPUs..." message when loading trainer

---
 gdeep/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gdeep/trainer/trainer.py b/gdeep/trainer/trainer.py
index cae2dcdb..96622496 100644
--- a/gdeep/trainer/trainer.py
+++ b/gdeep/trainer/trainer.py
@@ -39,7 +39,7 @@
         pass
     print("Using TPU!")
 except ModuleNotFoundError:
-    print("No TPUs...")
+    pass
 
 
 def _add_data_to_tb(func: Callable[[Any], Any]) -> Callable[[Any], Any]:

From 39cd823577d4f33e5d9b07781fd7b6a35a7f047a Mon Sep 17 00:00:00 2001
From: raphaelreinauer
Date: Sun, 14 Apr 2024 14:42:05 +0200
Subject: [PATCH 2/7] Cleanup

---
 gdeep/data/datasets/_data_cloud.py           | 399 ------------------
 gdeep/data/datasets/tests/test_data_cloud.py | 148 -------
 gdeep/data/preprocessors/tests/__init__.py   |   1 -
 3 files changed, 548 deletions(-)
 delete mode 100644 gdeep/data/datasets/_data_cloud.py
 delete mode 100644 gdeep/data/datasets/tests/test_data_cloud.py

diff --git a/gdeep/data/datasets/_data_cloud.py b/gdeep/data/datasets/_data_cloud.py
deleted file mode 100644
index 67792750..00000000
--- a/gdeep/data/datasets/_data_cloud.py
+++ /dev/null
@@ -1,399 +0,0 @@
-import logging
-import os
-from os import listdir, makedirs
-from os.path import isfile, join, isdir, exists, getsize
-import requests  # type: ignore
-import sys
-from typing import Optional, Union, List
-import time
-
-import google
-from google.cloud import storage  # type: ignore
-from google.oauth2 import service_account  # type: ignore
-import wget  # type: ignore
-
-from gdeep.utility.constants import DEFAULT_DOWNLOAD_DIR, DATASET_BUCKET_NAME
-from gdeep.utility.utils import get_checksum
-
-LOGGER = logging.getLogger(__name__)
-LOGGER.setLevel(logging.WARNING)
-
-
-def _check_public_access(use_public_access: bool):
-    """Check if the public access is enabled."""
-
-    def wrap(function):
-        def wrapper_function(*args, **kwargs):
-            if use_public_access:
-                raise ValueError("DataCloud object has public access only!")
-            return function(*args, **kwargs)
-
-        return wrapper_function
-
-    return wrap
-
-
-class _DataCloud:
-    """Download handle for Google Cloud Storage buckets.
-
-    Args:
-        bucket_name (str, optional):
-            Name of the Google Cloud Storage bucket.
-            Defaults to "adversarial_attack".
-        download_directory (str, optional):
-            Directory of the downloaded files.
-            Defaults to join('examples', 'data', 'DataCloud').
-        use_public_access: (bool, optional):
-            Whether or not to use public api access.
-            Defaults to True.
-        path_credentials (str, optional):
-            Path to the credentials file.
-            Only used if public_access is False and credentials are not
-            provided. Defaults to None.
-
-    Raises:
-        ValueError: If the bucket does not exist.
- - Returns: - None - """ - - def __init__( - self, - bucket_name: str = DATASET_BUCKET_NAME, - download_directory: str = DEFAULT_DOWNLOAD_DIR, - use_public_access: bool = True, - path_to_credentials: Union[str, None] = None, - ) -> None: - self.bucket_name = bucket_name - self.use_public_access = use_public_access - if not self.use_public_access: - # Get storage client - if path_to_credentials is None: - self.storage_client = storage.Client() - else: - credentials = service_account.Credentials.from_service_account_file( - path_to_credentials - ) - self.storage_client = storage.Client(credentials=credentials) - self.bucket = self.storage_client.bucket(self.bucket_name) - else: - self.public_url = "https://storage.googleapis.com/" + bucket_name + "/" - - # Set up download path - self.download_directory = download_directory - - # Create a new directory because it does not exist - if not exists(self.download_directory) and self.download_directory != "": - makedirs(self.download_directory) - - def list_blobs(self) -> List[str]: - """List all blobs in the bucket. - - Returns: - List[str]: - List of blobs in the bucket. - """ - # Assert that the bucket does not use public access - if self.use_public_access: - raise ValueError( - "DataCloud object can only list blobs" "when using private access!" - ) - blobs = self.bucket.list_blobs() - return [blob.name for blob in blobs] - - def blob_exists(self, blob_name: str) -> bool: - """Check if a Blob exists in the bucket. - - Args: - blob_name (str): - Name of the Blob to check. - - - Returns: - bool: - True if the Blob exists, False otherwise. - """ - if self.use_public_access: - url = self.public_url + blob_name - response = requests.head(url) - return response.status_code == 200 - else: - blob = self.bucket.blob(blob_name) - return blob.exists() - - def download_file( - self, blob_name: str, download_directory: Union[str, None] = None - ) -> None: - """Download a blob from Google Cloud Storage bucket. - - Args: - source_blob_name (str): - Name of the blob to download. The name is relative to the - root of the bucket. - download_directory (str, optional): - Directory to download the blob to. - - Raises: - ValueError: - If the blob does not exist. - - Returns: - None - """ - url = "" - if download_directory is None: - download_directory = self.download_directory - if self.use_public_access: - url = self.public_url + blob_name - # Check if blob exists - if not self.blob_exists(blob_name): - raise google.api_core.exceptions.NotFound( # type: ignore - "Blob {} does not exist!".format(blob_name) - ) - - # If the file exists, compare checksums - if isfile(join(download_directory, blob_name)): - # Get remote md5 checksum from url in base64 format. - if self.use_public_access: - response = requests.get(url, stream=True) - response.raw.decode_content = True - # Get remote md5 checksum from url in base64 format if it exists. - checksum_remote = response.headers.get("Content-MD5") - else: - blob = self.bucket.blob(blob_name) - checksum_remote = blob.md5_hash - - checksum_local = get_checksum( - join(download_directory, blob_name), encoding="base64" - ) - if checksum_remote is not None: - if checksum_remote != checksum_local: - # Ask user if they want to download the file - answer = input( - f"File {join(download_directory, blob_name)} already" - + "exists and checksums don't match! " - + "Do you want to overwrite it? 
[y/N]" - ) - if answer.lower() not in ["y", "yes"]: - return - else: - print( - f"File {join(download_directory, blob_name)} " - + "already exists and checksums match! " - + "Skipping download." - ) - return - else: - print( - f"File {join(download_directory, blob_name)} already" - + "exists and remote checksum is " - + "None! Downloading anyway." - ) - # Download file - print("Downloading file {} to {}".format(blob_name, download_directory)) - if self.use_public_access: - wget.download(url, join(download_directory, blob_name)) - else: - self.bucket.blob(blob_name).download_to_filename( - join(download_directory, blob_name), checksum="md5" - ) - - def download_folder(self, blob_name: str) -> None: - """Download a folder from Google Cloud Storage bucket. - - Warning: This function does not download empty subdirectories. - - Args: - blob_name (str): - Name of the blob folder to download. The name is relative to - the root of the bucket. - - Raises: - RuntimeError: - If the folder does not exist. - - Returns: - None - """ - assert not self.use_public_access, ( - "Downloading folders is not" "is not supported with public" "access!" - ) - # Get list of files in the blob - blobs = self.bucket.list_blobs(prefix=blob_name) - for blob in blobs: - # Do not download subdirectories - if blob.name.endswith("/"): - continue - file_split = blob.name.split("/") - directory = "/".join(file_split[0:-1]) - if not exists(directory): - makedirs(join(self.download_directory, directory), exist_ok=True) - logging.getLogger().info("Downloading blob %s", blob.name) - - local_path = ( - blob.name.replace("/", "\\") if sys.platform == "win32" else blob.name - ) - - blob.download_to_filename( - join(self.download_directory, local_path), checksum="md5" - ) - - def upload_file( - self, - source_file_name: str, - target_blob_name: Union[str, None] = None, - make_public: bool = False, - overwrite: bool = False, - ) -> None: - """Upload a local file to Google Cloud Storage bucket. - - Args: - source_file_name (str): - Filename of the local file to upload. - target_blob_name (Union[str, None], optional): - Name of the target blob relative to the root of the bucket. - If None, the filename will be used. - Defaults to None. - make_public (bool, optional): - Whether or not to make the uploaded - file public. Defaults to False. - overwrite (bool, optional): - Whether or not to overwrite the target - Blob. Defaults to False. - - Raises: - RuntimeError: If the target Blob already exists. - - Returns: - None - """ - if target_blob_name is None: - target_blob_name = os.path.basename(source_file_name) - blob = self.bucket.blob(target_blob_name) - if blob.exists() and not overwrite: - raise RuntimeError(f"Blob {target_blob_name} already exists.") - logging.getLogger().info("upload file %s", source_file_name) - # Check if source_file_name is bigger than 5GB - if isfile(source_file_name) and getsize(source_file_name) > 5000000000: - raise ValueError("File is bigger than 5GB") - - # Compute MD5 checksum of the file and add it to the metadata of - # the blob - blob.md5_hash = get_checksum(source_file_name, encoding="base64") - - blob.upload_from_filename(source_file_name, checksum="md5") - if make_public: - blob.make_public() - - def upload_folder( - self, - source_folder: str, - target_folder: Optional[str] = None, - make_public: bool = False, - ) -> None: - """Upload a local folder with all it's subolders to Google - Cloud Storage bucket. - - Args: - source_folder (str): - Folder to upload. 
- target_folder (Union[str, None], optional): - Name of the target folder relative to the root of the bucket. - If None, the root of the bucket will be used. - Defaults to None. - make_public (bool, optional): - Whether or not to make the uploaded - file public. Defaults to False. - - Raises: - ValueError: - If the source folder is not a directory. - - Returns: - None - """ - if not isdir(source_folder): - raise ValueError("Source folder is not a directory.") - - if target_folder is None: - target_folder = "" - - # List of all files in the source folder - files = [ - join(source_folder, f) - for f in listdir(source_folder) - if isfile(join(source_folder, f)) - ] - print(files) - # Upload all files in the source folder - for file in files: - file_name = os.path.basename(file) - if target_folder == "": - self.upload_file( - join(source_folder, file_name), - target_blob_name=file_name, - make_public=make_public, - ) - else: - self.upload_file( - join(source_folder, file_name), - target_blob_name=target_folder + "/" + file_name, - make_public=make_public, - ) - - # List of all subfolders in the source folder - subfolders = [ - join(source_folder, f) - for f in listdir(source_folder) - if isdir(join(source_folder, f)) - ] - # Upload all subfolders in the source folder recursively - for subfolder in subfolders: - relative_subfolder = os.path.relpath(subfolder, source_folder) - if target_folder == "": - self.upload_folder( - join(source_folder, relative_subfolder), - target_folder=relative_subfolder, - make_public=make_public, - ) - else: - self.upload_folder( - join(source_folder, relative_subfolder), - target_folder=target_folder + "/" + relative_subfolder, - make_public=make_public, - ) - - def delete_blob(self, blob_name: str) -> None: - """Deletes a single Blob from Google Cloud Storage - - Args: - blob_name (str): - The name of the Blob to delete - - Raises: - RuntimeError: If the Blob does not exist. - - Returns: - None - """ - blob = self.bucket.blob(blob_name) - blob.delete() - - def delete_blobs(self, blobs_name: str) -> None: - """Deletes a Blob and all its children from Google Cloud Storage. - - Args: - blobs_name (str): - Name of the parent Blob to delete. - - Raises: - ValueError: - If the Blob does not exist. 
- - Returns: - None - """ - blobs = self.bucket.list_blobs(prefix=blobs_name) - for blob in blobs: - blob.delete() diff --git a/gdeep/data/datasets/tests/test_data_cloud.py b/gdeep/data/datasets/tests/test_data_cloud.py deleted file mode 100644 index 8c39a8d9..00000000 --- a/gdeep/data/datasets/tests/test_data_cloud.py +++ /dev/null @@ -1,148 +0,0 @@ -# %% -from os import remove, makedirs, environ -from os.path import join, exists -import logging - -import google # type: ignore -from google.cloud import storage # type: ignore -from google.cloud.storage import Bucket # type: ignore -from google.auth.exceptions import DefaultCredentialsError # type: ignore -import hashlib -import pytest -import random -from shutil import rmtree - -from gdeep.data.datasets import _DataCloud -from gdeep.utility.utils import get_checksum -from gdeep.utility import DATASET_BUCKET_NAME - -LOGGER = logging.getLogger(__name__) - - -if "GOOGLE_APPLICATION_CREDENTIALS" in dict(environ): - # Check if the credentials are valid and if the bucket can be accessed - client = storage.Client() - if Bucket(client, DATASET_BUCKET_NAME).exists(): - - def test_download(): - """Test download of sample data from bucket""" - data_cloud = _DataCloud(use_public_access=False) - file_name = "giotto-deep-big.png" - data_cloud.download_file(file_name) - - # check if correct extension is raised when trying to download - # non-existing file - non_existing_file_name: str = "giotto-deep-bigs.png" - with pytest.raises(google.api_core.exceptions.NotFound): # type: ignore - data_cloud.download_file(non_existing_file_name) - - # check if downloaded file exists - file_path = join(data_cloud.download_directory, file_name) - assert exists(file_path) - - # check if downloaded file is correct - assert "d4b12b2dc2bc199831ba803431184fcb" == get_checksum(file_path) - - remove(join(data_cloud.download_directory, file_name)) - - def test_upload(): - """Test upload of sample file to bucket.""" - data_cloud = _DataCloud(use_public_access=False) - - # create temporary file to upload to bucket - sample_file_name = "tmp.txt" - sample_text = "Create a new tmp file!" 
+ str(random.randint(0, 1_000)) - - if exists(sample_file_name): - remove(sample_file_name) - with open(sample_file_name, "w") as f: - f.write(sample_text) - - assert exists(sample_file_name) - - # upload sample file to bucket - data_cloud.upload_file(sample_file_name) - - # check if uploading to an already existing file raises exception - with pytest.raises(RuntimeError): - data_cloud.upload_file(sample_file_name) - - # delete local temporary file - remove(sample_file_name) - - data_cloud.download_file(sample_file_name) - - data_cloud.delete_blob(sample_file_name) - - # check if downloaded file exists - file_path = join(data_cloud.download_directory, sample_file_name) - assert exists(file_path) - - with open(file_path, "r") as f: - assert f.read() == sample_text - - remove(file_path) - - def test_upload_folder(): - """Test the upload of a folder to bucket and download the - folder.""" - data_cloud = _DataCloud(use_public_access=False) - - # create temporary folder structure and temporary file to upload - # to bucket - # tmp: tmp.txt - # |- sub_tmp_1: tmp1.txt - # |- sub_tmp_2: tmp2.txt - # |- sub_tmp_2_2: tmp2_2.txt - if exists("tmp"): - rmtree("tmp") - - tmp_files = [] - - sample_dir = "tmp" - makedirs(sample_dir) - tmp_files.append(join(sample_dir, "tmp.txt")) - - sub_1_sample_dir = join(sample_dir, "sub_tmp_1") - makedirs(sub_1_sample_dir) - tmp_files.append(join(sub_1_sample_dir, "tmp1.txt")) - - sub_2_sample_dir = join(sample_dir, "sub_tmp_2") - makedirs(sub_2_sample_dir) - tmp_files.append(join(sub_2_sample_dir, "tmp2.txt")) - - sub_2_2_sample_dir = join(sub_2_sample_dir, "sub_tmp_2_2") - makedirs(sub_2_2_sample_dir) - tmp_files.append(join(sub_2_2_sample_dir, "tmp2_2.txt")) - - sample_texts = {} - - for file in tmp_files: - if exists(file): - remove(file) - sample_text = "Create a new tmp file! " + str(random.randint(0, 1_000)) - sample_texts[file] = sample_text - with open(file, "w") as f: - f.write(sample_text) - assert exists(file) - - # upload sample file to bucket - data_cloud.upload_folder(sample_dir, "tmp") - - # delete local tmp folder - rmtree(sample_dir) - - # download folder to local - data_cloud.download_folder(sample_dir) - - # delete folder in bucket - data_cloud.delete_blobs("tmp") - - # check if downloaded folder is correct - - for file in sample_texts.keys(): - with open(join(data_cloud.download_directory, file), "r") as f: - assert f.read() == sample_texts[file] - - # delete local tmp folder - rmtree(join(data_cloud.download_directory, sample_dir)) diff --git a/gdeep/data/preprocessors/tests/__init__.py b/gdeep/data/preprocessors/tests/__init__.py index 3d57db5a..e69de29b 100644 --- a/gdeep/data/preprocessors/tests/__init__.py +++ b/gdeep/data/preprocessors/tests/__init__.py @@ -1 +0,0 @@ -# from ..data_cloud import DataCloud From 01c89757f0669c9c81c7de1efca431c7fce71f55 Mon Sep 17 00:00:00 2001 From: raphaelreinauer Date: Sun, 14 Apr 2024 17:15:04 +0200 Subject: [PATCH 3/7] Revert "Cleanup" This reverts commit 39cd823577d4f33e5d9b07781fd7b6a35a7f047a. 
--- gdeep/data/datasets/_data_cloud.py | 399 +++++++++++++++++++ gdeep/data/datasets/tests/test_data_cloud.py | 148 +++++++ gdeep/data/preprocessors/tests/__init__.py | 1 + 3 files changed, 548 insertions(+) create mode 100644 gdeep/data/datasets/_data_cloud.py create mode 100644 gdeep/data/datasets/tests/test_data_cloud.py diff --git a/gdeep/data/datasets/_data_cloud.py b/gdeep/data/datasets/_data_cloud.py new file mode 100644 index 00000000..67792750 --- /dev/null +++ b/gdeep/data/datasets/_data_cloud.py @@ -0,0 +1,399 @@ +import logging +import os +from os import listdir, makedirs +from os.path import isfile, join, isdir, exists, getsize +import requests # type: ignore +import sys +from typing import Optional, Union, List +import time + +import google +from google.cloud import storage # type: ignore +from google.oauth2 import service_account # type: ignore +import wget # type: ignore + +from gdeep.utility.constants import DEFAULT_DOWNLOAD_DIR, DATASET_BUCKET_NAME +from gdeep.utility.utils import get_checksum + +LOGGER = logging.getLogger(__name__) +LOGGER.setLevel(logging.WARNING) + + +def _check_public_access(use_public_access: bool): + """Check if the public access is enabled.""" + + def wrap(function): + def wrapper_function(*args, **kwargs): + if use_public_access: + raise ValueError("DataCloud object has public access only!") + return function(*args, **kwargs) + + return wrapper_function + + return wrap + + +class _DataCloud: + """Download handle for Google Cloud Storage buckets. + + Args: + bucket_name (str, optional): + Name of the Google Cloud Storage bucket. + Defaults to "adversarial_attack". + download_directory (str, optional): + Directory of the downloaded files. + Defaults to join('examples', 'data', 'DataCloud'). + use_public_access: (bool, optional): + Whether or not to use public api access. + Defaults to True. + path_credentials (str, optional): + Path to the credentials file. + Only used if public_access is False and credentials are not + provided. Defaults to None. + + Raises: + ValueError: If the bucket does not exist. + + Returns: + None + """ + + def __init__( + self, + bucket_name: str = DATASET_BUCKET_NAME, + download_directory: str = DEFAULT_DOWNLOAD_DIR, + use_public_access: bool = True, + path_to_credentials: Union[str, None] = None, + ) -> None: + self.bucket_name = bucket_name + self.use_public_access = use_public_access + if not self.use_public_access: + # Get storage client + if path_to_credentials is None: + self.storage_client = storage.Client() + else: + credentials = service_account.Credentials.from_service_account_file( + path_to_credentials + ) + self.storage_client = storage.Client(credentials=credentials) + self.bucket = self.storage_client.bucket(self.bucket_name) + else: + self.public_url = "https://storage.googleapis.com/" + bucket_name + "/" + + # Set up download path + self.download_directory = download_directory + + # Create a new directory because it does not exist + if not exists(self.download_directory) and self.download_directory != "": + makedirs(self.download_directory) + + def list_blobs(self) -> List[str]: + """List all blobs in the bucket. + + Returns: + List[str]: + List of blobs in the bucket. + """ + # Assert that the bucket does not use public access + if self.use_public_access: + raise ValueError( + "DataCloud object can only list blobs" "when using private access!" 
+ ) + blobs = self.bucket.list_blobs() + return [blob.name for blob in blobs] + + def blob_exists(self, blob_name: str) -> bool: + """Check if a Blob exists in the bucket. + + Args: + blob_name (str): + Name of the Blob to check. + + + Returns: + bool: + True if the Blob exists, False otherwise. + """ + if self.use_public_access: + url = self.public_url + blob_name + response = requests.head(url) + return response.status_code == 200 + else: + blob = self.bucket.blob(blob_name) + return blob.exists() + + def download_file( + self, blob_name: str, download_directory: Union[str, None] = None + ) -> None: + """Download a blob from Google Cloud Storage bucket. + + Args: + source_blob_name (str): + Name of the blob to download. The name is relative to the + root of the bucket. + download_directory (str, optional): + Directory to download the blob to. + + Raises: + ValueError: + If the blob does not exist. + + Returns: + None + """ + url = "" + if download_directory is None: + download_directory = self.download_directory + if self.use_public_access: + url = self.public_url + blob_name + # Check if blob exists + if not self.blob_exists(blob_name): + raise google.api_core.exceptions.NotFound( # type: ignore + "Blob {} does not exist!".format(blob_name) + ) + + # If the file exists, compare checksums + if isfile(join(download_directory, blob_name)): + # Get remote md5 checksum from url in base64 format. + if self.use_public_access: + response = requests.get(url, stream=True) + response.raw.decode_content = True + # Get remote md5 checksum from url in base64 format if it exists. + checksum_remote = response.headers.get("Content-MD5") + else: + blob = self.bucket.blob(blob_name) + checksum_remote = blob.md5_hash + + checksum_local = get_checksum( + join(download_directory, blob_name), encoding="base64" + ) + if checksum_remote is not None: + if checksum_remote != checksum_local: + # Ask user if they want to download the file + answer = input( + f"File {join(download_directory, blob_name)} already" + + "exists and checksums don't match! " + + "Do you want to overwrite it? [y/N]" + ) + if answer.lower() not in ["y", "yes"]: + return + else: + print( + f"File {join(download_directory, blob_name)} " + + "already exists and checksums match! " + + "Skipping download." + ) + return + else: + print( + f"File {join(download_directory, blob_name)} already" + + "exists and remote checksum is " + + "None! Downloading anyway." + ) + # Download file + print("Downloading file {} to {}".format(blob_name, download_directory)) + if self.use_public_access: + wget.download(url, join(download_directory, blob_name)) + else: + self.bucket.blob(blob_name).download_to_filename( + join(download_directory, blob_name), checksum="md5" + ) + + def download_folder(self, blob_name: str) -> None: + """Download a folder from Google Cloud Storage bucket. + + Warning: This function does not download empty subdirectories. + + Args: + blob_name (str): + Name of the blob folder to download. The name is relative to + the root of the bucket. + + Raises: + RuntimeError: + If the folder does not exist. + + Returns: + None + """ + assert not self.use_public_access, ( + "Downloading folders is not" "is not supported with public" "access!" 
+ ) + # Get list of files in the blob + blobs = self.bucket.list_blobs(prefix=blob_name) + for blob in blobs: + # Do not download subdirectories + if blob.name.endswith("/"): + continue + file_split = blob.name.split("/") + directory = "/".join(file_split[0:-1]) + if not exists(directory): + makedirs(join(self.download_directory, directory), exist_ok=True) + logging.getLogger().info("Downloading blob %s", blob.name) + + local_path = ( + blob.name.replace("/", "\\") if sys.platform == "win32" else blob.name + ) + + blob.download_to_filename( + join(self.download_directory, local_path), checksum="md5" + ) + + def upload_file( + self, + source_file_name: str, + target_blob_name: Union[str, None] = None, + make_public: bool = False, + overwrite: bool = False, + ) -> None: + """Upload a local file to Google Cloud Storage bucket. + + Args: + source_file_name (str): + Filename of the local file to upload. + target_blob_name (Union[str, None], optional): + Name of the target blob relative to the root of the bucket. + If None, the filename will be used. + Defaults to None. + make_public (bool, optional): + Whether or not to make the uploaded + file public. Defaults to False. + overwrite (bool, optional): + Whether or not to overwrite the target + Blob. Defaults to False. + + Raises: + RuntimeError: If the target Blob already exists. + + Returns: + None + """ + if target_blob_name is None: + target_blob_name = os.path.basename(source_file_name) + blob = self.bucket.blob(target_blob_name) + if blob.exists() and not overwrite: + raise RuntimeError(f"Blob {target_blob_name} already exists.") + logging.getLogger().info("upload file %s", source_file_name) + # Check if source_file_name is bigger than 5GB + if isfile(source_file_name) and getsize(source_file_name) > 5000000000: + raise ValueError("File is bigger than 5GB") + + # Compute MD5 checksum of the file and add it to the metadata of + # the blob + blob.md5_hash = get_checksum(source_file_name, encoding="base64") + + blob.upload_from_filename(source_file_name, checksum="md5") + if make_public: + blob.make_public() + + def upload_folder( + self, + source_folder: str, + target_folder: Optional[str] = None, + make_public: bool = False, + ) -> None: + """Upload a local folder with all it's subolders to Google + Cloud Storage bucket. + + Args: + source_folder (str): + Folder to upload. + target_folder (Union[str, None], optional): + Name of the target folder relative to the root of the bucket. + If None, the root of the bucket will be used. + Defaults to None. + make_public (bool, optional): + Whether or not to make the uploaded + file public. Defaults to False. + + Raises: + ValueError: + If the source folder is not a directory. 
+ + Returns: + None + """ + if not isdir(source_folder): + raise ValueError("Source folder is not a directory.") + + if target_folder is None: + target_folder = "" + + # List of all files in the source folder + files = [ + join(source_folder, f) + for f in listdir(source_folder) + if isfile(join(source_folder, f)) + ] + print(files) + # Upload all files in the source folder + for file in files: + file_name = os.path.basename(file) + if target_folder == "": + self.upload_file( + join(source_folder, file_name), + target_blob_name=file_name, + make_public=make_public, + ) + else: + self.upload_file( + join(source_folder, file_name), + target_blob_name=target_folder + "/" + file_name, + make_public=make_public, + ) + + # List of all subfolders in the source folder + subfolders = [ + join(source_folder, f) + for f in listdir(source_folder) + if isdir(join(source_folder, f)) + ] + # Upload all subfolders in the source folder recursively + for subfolder in subfolders: + relative_subfolder = os.path.relpath(subfolder, source_folder) + if target_folder == "": + self.upload_folder( + join(source_folder, relative_subfolder), + target_folder=relative_subfolder, + make_public=make_public, + ) + else: + self.upload_folder( + join(source_folder, relative_subfolder), + target_folder=target_folder + "/" + relative_subfolder, + make_public=make_public, + ) + + def delete_blob(self, blob_name: str) -> None: + """Deletes a single Blob from Google Cloud Storage + + Args: + blob_name (str): + The name of the Blob to delete + + Raises: + RuntimeError: If the Blob does not exist. + + Returns: + None + """ + blob = self.bucket.blob(blob_name) + blob.delete() + + def delete_blobs(self, blobs_name: str) -> None: + """Deletes a Blob and all its children from Google Cloud Storage. + + Args: + blobs_name (str): + Name of the parent Blob to delete. + + Raises: + ValueError: + If the Blob does not exist. 
+ + Returns: + None + """ + blobs = self.bucket.list_blobs(prefix=blobs_name) + for blob in blobs: + blob.delete() diff --git a/gdeep/data/datasets/tests/test_data_cloud.py b/gdeep/data/datasets/tests/test_data_cloud.py new file mode 100644 index 00000000..8c39a8d9 --- /dev/null +++ b/gdeep/data/datasets/tests/test_data_cloud.py @@ -0,0 +1,148 @@ +# %% +from os import remove, makedirs, environ +from os.path import join, exists +import logging + +import google # type: ignore +from google.cloud import storage # type: ignore +from google.cloud.storage import Bucket # type: ignore +from google.auth.exceptions import DefaultCredentialsError # type: ignore +import hashlib +import pytest +import random +from shutil import rmtree + +from gdeep.data.datasets import _DataCloud +from gdeep.utility.utils import get_checksum +from gdeep.utility import DATASET_BUCKET_NAME + +LOGGER = logging.getLogger(__name__) + + +if "GOOGLE_APPLICATION_CREDENTIALS" in dict(environ): + # Check if the credentials are valid and if the bucket can be accessed + client = storage.Client() + if Bucket(client, DATASET_BUCKET_NAME).exists(): + + def test_download(): + """Test download of sample data from bucket""" + data_cloud = _DataCloud(use_public_access=False) + file_name = "giotto-deep-big.png" + data_cloud.download_file(file_name) + + # check if correct extension is raised when trying to download + # non-existing file + non_existing_file_name: str = "giotto-deep-bigs.png" + with pytest.raises(google.api_core.exceptions.NotFound): # type: ignore + data_cloud.download_file(non_existing_file_name) + + # check if downloaded file exists + file_path = join(data_cloud.download_directory, file_name) + assert exists(file_path) + + # check if downloaded file is correct + assert "d4b12b2dc2bc199831ba803431184fcb" == get_checksum(file_path) + + remove(join(data_cloud.download_directory, file_name)) + + def test_upload(): + """Test upload of sample file to bucket.""" + data_cloud = _DataCloud(use_public_access=False) + + # create temporary file to upload to bucket + sample_file_name = "tmp.txt" + sample_text = "Create a new tmp file!" 
+ str(random.randint(0, 1_000)) + + if exists(sample_file_name): + remove(sample_file_name) + with open(sample_file_name, "w") as f: + f.write(sample_text) + + assert exists(sample_file_name) + + # upload sample file to bucket + data_cloud.upload_file(sample_file_name) + + # check if uploading to an already existing file raises exception + with pytest.raises(RuntimeError): + data_cloud.upload_file(sample_file_name) + + # delete local temporary file + remove(sample_file_name) + + data_cloud.download_file(sample_file_name) + + data_cloud.delete_blob(sample_file_name) + + # check if downloaded file exists + file_path = join(data_cloud.download_directory, sample_file_name) + assert exists(file_path) + + with open(file_path, "r") as f: + assert f.read() == sample_text + + remove(file_path) + + def test_upload_folder(): + """Test the upload of a folder to bucket and download the + folder.""" + data_cloud = _DataCloud(use_public_access=False) + + # create temporary folder structure and temporary file to upload + # to bucket + # tmp: tmp.txt + # |- sub_tmp_1: tmp1.txt + # |- sub_tmp_2: tmp2.txt + # |- sub_tmp_2_2: tmp2_2.txt + if exists("tmp"): + rmtree("tmp") + + tmp_files = [] + + sample_dir = "tmp" + makedirs(sample_dir) + tmp_files.append(join(sample_dir, "tmp.txt")) + + sub_1_sample_dir = join(sample_dir, "sub_tmp_1") + makedirs(sub_1_sample_dir) + tmp_files.append(join(sub_1_sample_dir, "tmp1.txt")) + + sub_2_sample_dir = join(sample_dir, "sub_tmp_2") + makedirs(sub_2_sample_dir) + tmp_files.append(join(sub_2_sample_dir, "tmp2.txt")) + + sub_2_2_sample_dir = join(sub_2_sample_dir, "sub_tmp_2_2") + makedirs(sub_2_2_sample_dir) + tmp_files.append(join(sub_2_2_sample_dir, "tmp2_2.txt")) + + sample_texts = {} + + for file in tmp_files: + if exists(file): + remove(file) + sample_text = "Create a new tmp file! 
" + str(random.randint(0, 1_000)) + sample_texts[file] = sample_text + with open(file, "w") as f: + f.write(sample_text) + assert exists(file) + + # upload sample file to bucket + data_cloud.upload_folder(sample_dir, "tmp") + + # delete local tmp folder + rmtree(sample_dir) + + # download folder to local + data_cloud.download_folder(sample_dir) + + # delete folder in bucket + data_cloud.delete_blobs("tmp") + + # check if downloaded folder is correct + + for file in sample_texts.keys(): + with open(join(data_cloud.download_directory, file), "r") as f: + assert f.read() == sample_texts[file] + + # delete local tmp folder + rmtree(join(data_cloud.download_directory, sample_dir)) diff --git a/gdeep/data/preprocessors/tests/__init__.py b/gdeep/data/preprocessors/tests/__init__.py index e69de29b..3d57db5a 100644 --- a/gdeep/data/preprocessors/tests/__init__.py +++ b/gdeep/data/preprocessors/tests/__init__.py @@ -0,0 +1 @@ +# from ..data_cloud import DataCloud From 3b6d774f17e5ad6aaf2ceb12440d863a4ba6373f Mon Sep 17 00:00:00 2001 From: raphaelreinauer Date: Sun, 14 Apr 2024 21:51:56 +0200 Subject: [PATCH 4/7] Add DatasetUploader and refactored DatasetCloud --- gdeep/data/datasets/_data_cloud.py | 399 --------------- .../datasets/{ => cloud}/dataloader_cloud.py | 8 +- gdeep/data/datasets/cloud/dataset_cloud.py | 44 ++ gdeep/data/datasets/cloud/utils.py | 8 + gdeep/data/datasets/dataset_cloud.py | 470 ------------------ gdeep/data/datasets/dataset_uploader.py | 48 ++ .../data/datasets/tests/test_dataset_cloud.py | 47 +- gdeep/utility/constants.py | 5 +- requirements.txt | 2 + 9 files changed, 127 insertions(+), 904 deletions(-) delete mode 100644 gdeep/data/datasets/_data_cloud.py rename gdeep/data/datasets/{ => cloud}/dataloader_cloud.py (96%) create mode 100644 gdeep/data/datasets/cloud/dataset_cloud.py create mode 100644 gdeep/data/datasets/cloud/utils.py delete mode 100644 gdeep/data/datasets/dataset_cloud.py create mode 100644 gdeep/data/datasets/dataset_uploader.py diff --git a/gdeep/data/datasets/_data_cloud.py b/gdeep/data/datasets/_data_cloud.py deleted file mode 100644 index 67792750..00000000 --- a/gdeep/data/datasets/_data_cloud.py +++ /dev/null @@ -1,399 +0,0 @@ -import logging -import os -from os import listdir, makedirs -from os.path import isfile, join, isdir, exists, getsize -import requests # type: ignore -import sys -from typing import Optional, Union, List -import time - -import google -from google.cloud import storage # type: ignore -from google.oauth2 import service_account # type: ignore -import wget # type: ignore - -from gdeep.utility.constants import DEFAULT_DOWNLOAD_DIR, DATASET_BUCKET_NAME -from gdeep.utility.utils import get_checksum - -LOGGER = logging.getLogger(__name__) -LOGGER.setLevel(logging.WARNING) - - -def _check_public_access(use_public_access: bool): - """Check if the public access is enabled.""" - - def wrap(function): - def wrapper_function(*args, **kwargs): - if use_public_access: - raise ValueError("DataCloud object has public access only!") - return function(*args, **kwargs) - - return wrapper_function - - return wrap - - -class _DataCloud: - """Download handle for Google Cloud Storage buckets. - - Args: - bucket_name (str, optional): - Name of the Google Cloud Storage bucket. - Defaults to "adversarial_attack". - download_directory (str, optional): - Directory of the downloaded files. - Defaults to join('examples', 'data', 'DataCloud'). - use_public_access: (bool, optional): - Whether or not to use public api access. - Defaults to True. 
- path_credentials (str, optional): - Path to the credentials file. - Only used if public_access is False and credentials are not - provided. Defaults to None. - - Raises: - ValueError: If the bucket does not exist. - - Returns: - None - """ - - def __init__( - self, - bucket_name: str = DATASET_BUCKET_NAME, - download_directory: str = DEFAULT_DOWNLOAD_DIR, - use_public_access: bool = True, - path_to_credentials: Union[str, None] = None, - ) -> None: - self.bucket_name = bucket_name - self.use_public_access = use_public_access - if not self.use_public_access: - # Get storage client - if path_to_credentials is None: - self.storage_client = storage.Client() - else: - credentials = service_account.Credentials.from_service_account_file( - path_to_credentials - ) - self.storage_client = storage.Client(credentials=credentials) - self.bucket = self.storage_client.bucket(self.bucket_name) - else: - self.public_url = "https://storage.googleapis.com/" + bucket_name + "/" - - # Set up download path - self.download_directory = download_directory - - # Create a new directory because it does not exist - if not exists(self.download_directory) and self.download_directory != "": - makedirs(self.download_directory) - - def list_blobs(self) -> List[str]: - """List all blobs in the bucket. - - Returns: - List[str]: - List of blobs in the bucket. - """ - # Assert that the bucket does not use public access - if self.use_public_access: - raise ValueError( - "DataCloud object can only list blobs" "when using private access!" - ) - blobs = self.bucket.list_blobs() - return [blob.name for blob in blobs] - - def blob_exists(self, blob_name: str) -> bool: - """Check if a Blob exists in the bucket. - - Args: - blob_name (str): - Name of the Blob to check. - - - Returns: - bool: - True if the Blob exists, False otherwise. - """ - if self.use_public_access: - url = self.public_url + blob_name - response = requests.head(url) - return response.status_code == 200 - else: - blob = self.bucket.blob(blob_name) - return blob.exists() - - def download_file( - self, blob_name: str, download_directory: Union[str, None] = None - ) -> None: - """Download a blob from Google Cloud Storage bucket. - - Args: - source_blob_name (str): - Name of the blob to download. The name is relative to the - root of the bucket. - download_directory (str, optional): - Directory to download the blob to. - - Raises: - ValueError: - If the blob does not exist. - - Returns: - None - """ - url = "" - if download_directory is None: - download_directory = self.download_directory - if self.use_public_access: - url = self.public_url + blob_name - # Check if blob exists - if not self.blob_exists(blob_name): - raise google.api_core.exceptions.NotFound( # type: ignore - "Blob {} does not exist!".format(blob_name) - ) - - # If the file exists, compare checksums - if isfile(join(download_directory, blob_name)): - # Get remote md5 checksum from url in base64 format. - if self.use_public_access: - response = requests.get(url, stream=True) - response.raw.decode_content = True - # Get remote md5 checksum from url in base64 format if it exists. 
- checksum_remote = response.headers.get("Content-MD5") - else: - blob = self.bucket.blob(blob_name) - checksum_remote = blob.md5_hash - - checksum_local = get_checksum( - join(download_directory, blob_name), encoding="base64" - ) - if checksum_remote is not None: - if checksum_remote != checksum_local: - # Ask user if they want to download the file - answer = input( - f"File {join(download_directory, blob_name)} already" - + "exists and checksums don't match! " - + "Do you want to overwrite it? [y/N]" - ) - if answer.lower() not in ["y", "yes"]: - return - else: - print( - f"File {join(download_directory, blob_name)} " - + "already exists and checksums match! " - + "Skipping download." - ) - return - else: - print( - f"File {join(download_directory, blob_name)} already" - + "exists and remote checksum is " - + "None! Downloading anyway." - ) - # Download file - print("Downloading file {} to {}".format(blob_name, download_directory)) - if self.use_public_access: - wget.download(url, join(download_directory, blob_name)) - else: - self.bucket.blob(blob_name).download_to_filename( - join(download_directory, blob_name), checksum="md5" - ) - - def download_folder(self, blob_name: str) -> None: - """Download a folder from Google Cloud Storage bucket. - - Warning: This function does not download empty subdirectories. - - Args: - blob_name (str): - Name of the blob folder to download. The name is relative to - the root of the bucket. - - Raises: - RuntimeError: - If the folder does not exist. - - Returns: - None - """ - assert not self.use_public_access, ( - "Downloading folders is not" "is not supported with public" "access!" - ) - # Get list of files in the blob - blobs = self.bucket.list_blobs(prefix=blob_name) - for blob in blobs: - # Do not download subdirectories - if blob.name.endswith("/"): - continue - file_split = blob.name.split("/") - directory = "/".join(file_split[0:-1]) - if not exists(directory): - makedirs(join(self.download_directory, directory), exist_ok=True) - logging.getLogger().info("Downloading blob %s", blob.name) - - local_path = ( - blob.name.replace("/", "\\") if sys.platform == "win32" else blob.name - ) - - blob.download_to_filename( - join(self.download_directory, local_path), checksum="md5" - ) - - def upload_file( - self, - source_file_name: str, - target_blob_name: Union[str, None] = None, - make_public: bool = False, - overwrite: bool = False, - ) -> None: - """Upload a local file to Google Cloud Storage bucket. - - Args: - source_file_name (str): - Filename of the local file to upload. - target_blob_name (Union[str, None], optional): - Name of the target blob relative to the root of the bucket. - If None, the filename will be used. - Defaults to None. - make_public (bool, optional): - Whether or not to make the uploaded - file public. Defaults to False. - overwrite (bool, optional): - Whether or not to overwrite the target - Blob. Defaults to False. - - Raises: - RuntimeError: If the target Blob already exists. 
- - Returns: - None - """ - if target_blob_name is None: - target_blob_name = os.path.basename(source_file_name) - blob = self.bucket.blob(target_blob_name) - if blob.exists() and not overwrite: - raise RuntimeError(f"Blob {target_blob_name} already exists.") - logging.getLogger().info("upload file %s", source_file_name) - # Check if source_file_name is bigger than 5GB - if isfile(source_file_name) and getsize(source_file_name) > 5000000000: - raise ValueError("File is bigger than 5GB") - - # Compute MD5 checksum of the file and add it to the metadata of - # the blob - blob.md5_hash = get_checksum(source_file_name, encoding="base64") - - blob.upload_from_filename(source_file_name, checksum="md5") - if make_public: - blob.make_public() - - def upload_folder( - self, - source_folder: str, - target_folder: Optional[str] = None, - make_public: bool = False, - ) -> None: - """Upload a local folder with all it's subolders to Google - Cloud Storage bucket. - - Args: - source_folder (str): - Folder to upload. - target_folder (Union[str, None], optional): - Name of the target folder relative to the root of the bucket. - If None, the root of the bucket will be used. - Defaults to None. - make_public (bool, optional): - Whether or not to make the uploaded - file public. Defaults to False. - - Raises: - ValueError: - If the source folder is not a directory. - - Returns: - None - """ - if not isdir(source_folder): - raise ValueError("Source folder is not a directory.") - - if target_folder is None: - target_folder = "" - - # List of all files in the source folder - files = [ - join(source_folder, f) - for f in listdir(source_folder) - if isfile(join(source_folder, f)) - ] - print(files) - # Upload all files in the source folder - for file in files: - file_name = os.path.basename(file) - if target_folder == "": - self.upload_file( - join(source_folder, file_name), - target_blob_name=file_name, - make_public=make_public, - ) - else: - self.upload_file( - join(source_folder, file_name), - target_blob_name=target_folder + "/" + file_name, - make_public=make_public, - ) - - # List of all subfolders in the source folder - subfolders = [ - join(source_folder, f) - for f in listdir(source_folder) - if isdir(join(source_folder, f)) - ] - # Upload all subfolders in the source folder recursively - for subfolder in subfolders: - relative_subfolder = os.path.relpath(subfolder, source_folder) - if target_folder == "": - self.upload_folder( - join(source_folder, relative_subfolder), - target_folder=relative_subfolder, - make_public=make_public, - ) - else: - self.upload_folder( - join(source_folder, relative_subfolder), - target_folder=target_folder + "/" + relative_subfolder, - make_public=make_public, - ) - - def delete_blob(self, blob_name: str) -> None: - """Deletes a single Blob from Google Cloud Storage - - Args: - blob_name (str): - The name of the Blob to delete - - Raises: - RuntimeError: If the Blob does not exist. - - Returns: - None - """ - blob = self.bucket.blob(blob_name) - blob.delete() - - def delete_blobs(self, blobs_name: str) -> None: - """Deletes a Blob and all its children from Google Cloud Storage. - - Args: - blobs_name (str): - Name of the parent Blob to delete. - - Raises: - ValueError: - If the Blob does not exist. 
-
-        Returns:
-            None
-        """
-        blobs = self.bucket.list_blobs(prefix=blobs_name)
-        for blob in blobs:
-            blob.delete()
diff --git a/gdeep/data/datasets/dataloader_cloud.py b/gdeep/data/datasets/cloud/dataloader_cloud.py
similarity index 96%
rename from gdeep/data/datasets/dataloader_cloud.py
rename to gdeep/data/datasets/cloud/dataloader_cloud.py
index 36d9191a..9292886a 100644
--- a/gdeep/data/datasets/dataloader_cloud.py
+++ b/gdeep/data/datasets/cloud/dataloader_cloud.py
@@ -8,10 +8,10 @@
 import torch
 from torch.utils.data import DataLoader
 
-from .dataset_form_array import FromArray
+from ..dataset_form_array import FromArray
 from .dataset_cloud import DatasetCloud
-from .base_dataloaders import DataLoaderBuilder
-from .base_dataloaders import AbstractDataLoaderBuilder
+from ..base_dataloaders import DataLoaderBuilder
+from ..base_dataloaders import AbstractDataLoaderBuilder
 from gdeep.utility.custom_types import Tensor
 
 
@@ -161,7 +161,7 @@
         )
         dataset_cloud = DatasetCloud(
             self.dataset_name,
-            download_directory=self.download_directory,
+            root_download_directory=self.download_directory,
             path_to_credentials=path_to_credentials,
             use_public_access=use_public_access,
         )
diff --git a/gdeep/data/datasets/cloud/dataset_cloud.py b/gdeep/data/datasets/cloud/dataset_cloud.py
new file mode 100644
index 00000000..aceb3f66
--- /dev/null
+++ b/gdeep/data/datasets/cloud/dataset_cloud.py
@@ -0,0 +1,44 @@
+import os
+from os.path import exists
+from pathlib import Path
+from zenodo_client.api import Zenodo
+
+import yaml
+
+from gdeep.data.datasets.cloud.utils import get_config_path, get_download_directory
+from gdeep.utility.constants import DEFAULT_DOWNLOAD_DIR
+
+
+class DatasetCloud:
+    def __init__(
+        self,
+        dataset_name: str,
+        root_download_directory: str=DEFAULT_DOWNLOAD_DIR
+    ):
+        self.name = dataset_name
+        self.zenodo = Zenodo()
+        self.download_directory = get_download_directory(self.name, root_download_directory)
+        self.config = self._load_config()
+
+    def _load_config(self) -> dict:
+        config_path = get_config_path(self.name)
+        if not exists(config_path):
+            raise ValueError(f"Configuration file {config_path} does not exist.")
+        with open(config_path, "r") as file:
+            config = yaml.safe_load(file)
+        return config
+
+
+    def does_dataset_exist_locally(self) -> bool:
+        return exists(self.download_directory / self.name)
+
+    def download(self) -> None:
+        os.makedirs(self.download_directory, exist_ok=True)
+
+        for file in self.config["files"]:
+            self.zenodo.download(self.config["deposition_id"], file, parts=[str(self.download_directory)])
+
+
+
+
+
diff --git a/gdeep/data/datasets/cloud/utils.py b/gdeep/data/datasets/cloud/utils.py
new file mode 100644
index 00000000..c26ee481
--- /dev/null
+++ b/gdeep/data/datasets/cloud/utils.py
@@ -0,0 +1,8 @@
+from pathlib import Path
+from gdeep.utility.constants import DATASET_CLOUD_CONFIGS_DIR
+
+def get_config_path(dataset_name: str) -> Path:
+    return Path(DATASET_CLOUD_CONFIGS_DIR) / f"{dataset_name}.yaml"
+
+def get_download_directory(dataset_name: str, root_download_directory: str) -> Path:
+    return Path(root_download_directory) / dataset_name
\ No newline at end of file
diff --git a/gdeep/data/datasets/dataset_cloud.py b/gdeep/data/datasets/dataset_cloud.py
deleted file mode 100644
index 25fd7642..00000000
--- a/gdeep/data/datasets/dataset_cloud.py
+++ /dev/null
@@ -1,470 +0,0 @@
-import os
-from os import remove
-from os.path import join, exists
-from typing import List, Tuple, Union, Set
-
-import json
-import wget  # type: ignore
-
-from
._data_cloud import _DataCloud # type: ignore -from gdeep.utility.constants import DEFAULT_DOWNLOAD_DIR, DATASET_BUCKET_NAME - - -class DatasetCloud: - """DatasetCloud class to handle the download and upload - of datasets to the DataCloud. - If the download_directory does not exist, it will be created and - if a folder with the same name as the dataset exists in the - download directory, it will not be downloaded again. - If a folder with the same name as the dataset does not exists - locally, it will be created when downloading the dataset. - - Args: - dataset_name (str): - Name of the dataset to be downloaded or uploaded. - bucket_name (str, optional): - Name of the bucket in the DataCloud. - Defaults to DATASET_BUCKET_NAME. - download_directory (Union[None, str], optional): - Directory where the - dataset will be downloaded to. Defaults to DEFAULT_DOWNLOAD_DIR. - use_public_access (bool, optional): - If True, the dataset will be - downloaded via public url. Defaults to False. - path_credentials (Union[None, str], optional): - Path to the credentials file. - Only used if public_access is False and credentials are not - provided. Defaults to None. - make_public (bool, optional): - If True, the dataset will be made public - - Raises: - ValueError: - Dataset does not exits in cloud. - - Returns: - None - """ - - def __init__( - self, - dataset_name: str, - bucket_name: str = DATASET_BUCKET_NAME, - download_directory: Union[None, str] = None, - use_public_access: bool = True, - path_to_credentials: Union[None, str] = None, - make_public: bool = True, - ) -> None: - # Non-public datasets start with "private_" - if make_public or use_public_access or dataset_name.startswith("private_"): - self.name = dataset_name - else: - self.name = "private_" + dataset_name - self.path_metadata = None - self.use_public_access = use_public_access - if download_directory is None: - # If download_directory is None, the dataset will be downloaded - # to the default directory. - self.download_directory = DEFAULT_DOWNLOAD_DIR - else: - self.download_directory = download_directory - - self._data_cloud = _DataCloud( - bucket_name=bucket_name, - download_directory=self.download_directory, - use_public_access=use_public_access, - path_to_credentials=path_to_credentials, - ) - if use_public_access: - self.public_url = "https://storage.googleapis.com/" + bucket_name + "/" - self.make_public = make_public - - def __del__(self) -> None: - """This function deletes the metadata file if it exists. - - Returns: - None - """ - if self.path_metadata != None: - remove(self.path_metadata) # type: ignore - return None - - def download(self) -> None: - """Download a dataset from the DataCloud. If the dataset does not - exist in the cloud, an exception will be raised. If the dataset - exists locally in the download directory, the dataset will not be - downloaded again. - - Raises: - ValueError: - Dataset does not exits in cloud. - ValueError: - Dataset exists locally but checksums do not match. - """ - if self.use_public_access: - self._download_using_url() - else: - self._download_using_api() - - def _download_using_api(self) -> None: - """Downloads the dataset using the DataCloud API. - If the dataset does not exist in the bucket, an exception will - be raised. If the dataset exists locally in the download directory, - the dataset will not be downloaded again. - - Raises: - ValueError: - Dataset does not exits in cloud. - - Returns: - None - """ - self._check_public_access() - # List of existing datasets in the cloud. 
- existing_datasets: Set[str] = set( - [ - blob.name.split("/")[0] # type: ignore - for blob in self._data_cloud.bucket.list_blobs() # type: ignore # type: ignore - if blob.name != "giotto-deep-big.png" - ] - ) # type: ignore - if self.name not in existing_datasets: - raise ValueError( - "Dataset {} does not exist in the cloud.".format(self.name) - + "Available datasets are: {existing_datasets}." - ) - if not self._does_dataset_exist_locally(): - self._create_dataset_folder() - self._data_cloud.download_folder(self.name + "/") - - def _does_dataset_exist_locally(self) -> bool: - """Check if the dataset exists locally. - - Returns: - bool: True if the dataset exists locally, False otherwise. - """ - return exists(join(self.download_directory, self.name)) - - def _create_dataset_folder(self) -> None: - """Creates a folder with the dataset name in the download directory. - - Returns: - None - """ - if not exists(join(self.download_directory, self.name)): - os.makedirs(join(self.download_directory, self.name)) - - def _download_using_url(self) -> None: - """Download the dataset using the public url. - If the dataset does not exist in the bucket, an exception will - be raised. If the dataset exists locally in the download directory, - the dataset will not be downloaded again. - - Raises: - ValueError: - Dataset does not exits in cloud. - - Returns: - None - """ - # List of existing datasets in the cloud. - existing_datasets = self.get_existing_datasets() - - # Check if requested dataset exists in the cloud. - assert ( - self.name in existing_datasets - ), "Dataset {} does not exist in the cloud.".format( - self.name - ) + "Available datasets are: {}.".format( - existing_datasets - ) - - # If the dataset does not exist locally, create the dataset folder. - if not self._does_dataset_exist_locally(): - self._create_dataset_folder() - - # Download the dataset (metadata.json, data.pt, labels.pt) - # by using the public URL. - self._data_cloud.download_file(self.name + "/metadata.json") - # load the metadata.json file to get the filetype - with open( - join(self.download_directory, self.name, "metadata.json") # type: ignore - ) as f: - metadata = json.load(f) - # filetype: Literal['pt', 'npy'] - if metadata["data_format"] == "pytorch_tensor": - filetype = "pt" - elif metadata["data_format"] == "numpy_array": - filetype = "npy" - else: - raise ValueError(f"Unknown data format: {metadata['data_format']}") - self._data_cloud.download_file(self.name + "/data." + filetype) - self._data_cloud.download_file(self.name + "/labels." + filetype) - - def get_existing_datasets(self) -> List[str]: - """Returns a list of datasets in the cloud. - - Returns: - List[str]: - List of datasets in the cloud. - """ - if self.use_public_access: - datasets_local = "tmp_datasets.json" - # Download the dataset list json file using the public URL. - wget.download(self.public_url + "datasets.json", datasets_local) # type: ignore - datasets = json.load(open(datasets_local)) - - # Remove duplicates. This has to be fixed in the future. - datasets = list(set(datasets)) - - # Remove the temporary file. - remove(datasets_local) - - return datasets - else: - existing_datasets = [ - blob_name.split("/")[0] - for blob_name in self._data_cloud.list_blobs() - if blob_name != "giotto-deep-big.png" and blob_name != "datasets.json" - ] - # Remove duplicates. - existing_datasets = list(set(existing_datasets)) - - # Remove dataset that are not public, i.e. start with "private_". 
- existing_datasets = [ - dataset - for dataset in existing_datasets - if not dataset.startswith("private_") - ] - - return existing_datasets - - def _update_dataset_list(self) -> None: - """Updates the dataset list in the datasets.json file. - - Returns: - None - """ - self._check_public_access() - - # List of existing datasets in the cloud. - existing_datasets = self.get_existing_datasets() - - # Save existing datasets to a json file. - json_file = "tmp_datasets.json" - json.dump(existing_datasets, open(json_file, "w")) - - # Upload the json file to the cloud. - self._data_cloud.upload_file( - json_file, - "datasets.json", - make_public=True, - overwrite=True, - ) - - # Remove the temporary file. - remove(json_file) - - @staticmethod - def _get_filetype(path: str) -> str: - """Returns the file extension from a given path. - - Args: - path: - A string path. - - Returns: - str: - The file extension. - - Raises: - None. - """ - return path.split(".")[-1] - - def _check_public_access(self) -> None: - """Check if use_public_access is set to False.""" - assert ( - self.use_public_access is False - ), "Only download functionality is supported for public access." - - def _upload_data( - self, - path: str, - ) -> None: - """Uploads the data file to a Cloud Storage bucket. - - Args: - path: - The path to the data file. - - Raises: - ValueError: - If the file type is not supported. - - Returns: - None - """ - self._check_public_access() - - filetype = DatasetCloud._get_filetype(path) - - # Check if the file type is supported - if filetype in ["pt", "npy"]: - self._data_cloud.upload_file( - path, - (self.metadata["name"] + "/data." + filetype), # type: ignore - make_public=self.make_public, - overwrite=False, - ) - else: - raise ValueError("File type {} is not supported.".format(filetype)) - - def _upload_label( - self, - path: str, - ) -> None: - """Uploads a set of labels to a remote dataset. - - Args: - path: - the path to the labels file. - - Raises: - ValueError: - If the file type is not supported. - - Returns: - None - - """ - self._check_public_access() - - filetype = DatasetCloud._get_filetype(path) - - # Check if the file type is supported - if filetype in ["pt", "npy"]: - self._data_cloud.upload_file( - path, - (self.metadata["name"] + "/labels." + filetype), # type: ignore - make_public=self.make_public, - overwrite=False, - ) - else: - raise ValueError("File type {} is not supported.".format(filetype)) - - def _upload_metadata(self, path: Union[str, None] = None) -> None: - """Uploads the metadata dictionary to the location specified in the - metadata. The metadata dictionary is generated using create_metadata. - - Args: - path (str): - The path to the data cloud folder. If none, path will - be set to the default path. - - Raises: - Exception: - If no metadata exists, an exception will be raised. - - Returns: - None - """ - self._check_public_access() - self._data_cloud.upload_file( - path, # type: ignore - str(self.metadata["name"]) + "/" + "metadata.json", # type: ignore - make_public=self.make_public, # type: ignore - ) - - def _add_metadata( - self, - size_dataset: int, - input_size: Tuple[int, ...], - num_labels: Union[None, int] = None, - data_type: str = "tabular", - task_type: str = "classification", - name: Union[None, str] = None, - data_format: Union[None, str] = None, - comment: Union[None, str] = None, - ) -> None: - """This function accepts various metadata for the dataset and stores it - in a temporary JSON file. 
- - Args: - size_dataset (int): - The size of the dataset (in terms of the number - of samples). - input_size (Tuple[int, ...]): - The size of each sample in the - dataset. - num_labels (Union[None, int]): - The number of classes in the dataset. - data_type (str): - The type of data in the dataset. - task_type (str): - The task type of the dataset. - name (Union[None, str]): - The name of the dataset. - data_format (Union[None, str]): - The format of the data in the dataset. - comment (Union[None, str]): - A comment describing the dataset. - - Returns: - None - """ - self._check_public_access() - if name is None: - name = self.name - if data_format is None: - data_format = "pytorch_tensor" - self.path_metadata = "tmp_metadata.json" # type: ignore - self.metadata = { - "name": name, - "size": size_dataset, - "input_size": input_size, - "num_labels": num_labels, - "task_type": task_type, - "data_type": data_type, - "data_format": data_format, - "comment": comment, - } - with open(self.path_metadata, "w") as f: # type: ignore - json.dump(self.metadata, f, sort_keys=True, indent=4) - - def _upload( - self, - path_data: str, - path_label: str, - path_metadata: Union[str, None] = None, - ) -> None: - """Uploads a dataset to the cloud. - - Args: - path_data (str): Path to the data files. - path_label (str): Path to the label file. - path_metadata (Optional[str]): Path to the metadata file. - - Raises: - ValueError: If the dataset already exists in the cloud. - - Returns: - None - """ - self._check_public_access() - - # List of existing datasets in the cloud. - existing_datasets = self.get_existing_datasets() - if self.name in existing_datasets: - raise ValueError( - "Dataset {} already exists in the cloud.".format(self.name) - + "Available datasets are: {}.".format(existing_datasets) - ) - if path_metadata is None: - path_metadata = self.path_metadata - self._upload_metadata(path_metadata) - self._upload_data(path_data) - self._upload_label(path_label) - - # Update dataset list. 
- self._update_dataset_list() diff --git a/gdeep/data/datasets/dataset_uploader.py b/gdeep/data/datasets/dataset_uploader.py new file mode 100644 index 00000000..1e4b4942 --- /dev/null +++ b/gdeep/data/datasets/dataset_uploader.py @@ -0,0 +1,48 @@ +from typing import Optional +import os +import yaml +from pathlib import Path +from zenodo_client.api import Zenodo +from zenodo_client.struct import Metadata +from gdeep.data.datasets.cloud.utils import get_config_path + +class DatasetUploader: + def __init__(self, access_token: Optional[str] = None, sandbox: bool = False): + if access_token is None: + access_token = os.getenv("ZENODO_API_TOKEN") + self.zenodo_client = Zenodo(access_token=access_token, sandbox=sandbox) + + def upload(self, dataset_name: str, metadata: Metadata, file_paths: list[str]) -> None: + config_path = get_config_path(dataset_name) + deposition = self.zenodo_client.create(data=metadata, paths=file_paths) + deposition_id = deposition["id"] + file_names = [file["filename"] for file in deposition["files"]] + config = { + "deposition_id": deposition_id, + "files": file_names + } + with open(config_path, "w") as f: + yaml.dump(config, f) + + def update(self, dataset_name: str, file_paths: list[str]) -> None: + config_path = get_config_path(dataset_name) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found for dataset: {dataset_name}") + with open(get_config_path(dataset_name), "r") as f: + config = yaml.safe_load(f) + deposition_id = config["deposition_id"] + deposition = self.zenodo_client.update(deposition_id, file_paths) + config["files"] = [file["filename"] for file in deposition["files"]] + with open(config_path, "w") as f: + yaml.dump(config, f) + + def remove(self, dataset_name: str) -> None: + config_path = get_config_path(dataset_name) + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found for dataset: {dataset_name}") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + deposition_id = config["deposition_id"] + self.zenodo_client.discard(deposition_id) + os.remove(config_path) + diff --git a/gdeep/data/datasets/tests/test_dataset_cloud.py b/gdeep/data/datasets/tests/test_dataset_cloud.py index 7cf27f29..549303f3 100644 --- a/gdeep/data/datasets/tests/test_dataset_cloud.py +++ b/gdeep/data/datasets/tests/test_dataset_cloud.py @@ -1,5 +1,6 @@ -from gdeep.data.datasets import DatasetCloud, dataset_cloud +from gdeep.data.datasets import DatasetCloud +import tempfile import hashlib import logging import os @@ -10,6 +11,7 @@ import numpy as np # type: ignore import torch +from gdeep.data.datasets.cloud import dataset_cloud from gdeep.utility.utils import get_checksum LOGGER = logging.getLogger(__name__) @@ -19,33 +21,24 @@ def test_public_access(): # Download a small dataset from Google Cloud Storage dataset = "SmallDataset" - download_directory = join("examples", "data", "DatasetCloud", "Tmp") - # Remove download directory recursively if it exists - if exists(download_directory): - rmtree(download_directory) + with tempfile.TemporaryDirectory() as download_directory: + dataset_cloud = DatasetCloud( + dataset, root_download_directory=download_directory, use_public_access=True + ) + dataset_cloud.download() - # Create download directory - os.makedirs(download_directory, exist_ok=False) + # Check if the downloaded files (metadata.json, data.json, labels.json) + # are correct + checksums = { + "data.pt": "2ef68a718e29134cbcbf46c9592f6168", + "labels.pt": 
"d71992425033c6bf449d175db146a423", + } - dataset_cloud = DatasetCloud( - dataset, download_directory=download_directory, use_public_access=True - ) - dataset_cloud.download() - - # Check if the downloaded files (metadata.json, data.json, labels.json) - # are correct - checksums = { - "data.pt": "2ef68a718e29134cbcbf46c9592f6168", - "labels.pt": "d71992425033c6bf449d175db146a423", - } - - for file in checksums.keys(): - assert ( - get_checksum(join(download_directory, dataset, file)) == checksums[file] - ), "File {} is corrupted.".format(file) - # Recursively remove download directory - rmtree(download_directory) + for file in checksums.keys(): + assert ( + get_checksum(join(download_directory, file)) == checksums[file] + ), "File {} is corrupted.".format(file) def test_get_dataset_list(): @@ -53,7 +46,7 @@ def test_get_dataset_list(): # It's only used for initialization of the DatasetCloud object download_directory = "" dataset_cloud = DatasetCloud( - "SmallDataset", download_directory=download_directory, use_public_access=True + "SmallDataset", root_download_directory=download_directory, use_public_access=True ) dataset_list = dataset_cloud.get_existing_datasets() assert len(dataset_list) > 0, "Dataset list is empty." @@ -109,7 +102,7 @@ def test_upload_and_download(): dataset_name = "TmpSmallDataset" dataset_cloud = DatasetCloud( dataset_name, - download_directory=download_directory, + root_download_directory=download_directory, use_public_access=False, ) diff --git a/gdeep/utility/constants.py b/gdeep/utility/constants.py index 7d1755b2..e8b7584d 100644 --- a/gdeep/utility/constants.py +++ b/gdeep/utility/constants.py @@ -13,10 +13,7 @@ # Define the default dataset download directory DEFAULT_DOWNLOAD_DIR = os.path.join(ROOT_DIR, "examples", "data", "DatasetCloud") -# Define the default dataset bucket on Google Cloud Storage where the datasets -# are stored -DATASET_BUCKET_NAME = "adversarial_attack" - +DATASET_CLOUD_CONFIGS_DIR = os.path.join(ROOT_DIR, "configs", "dataset_cloud") # Define the default dataset download directory where the graph # datasets from the PyG (PyTorch Geometric) library are stored diff --git a/requirements.txt b/requirements.txt index 16937e01..99adef39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,3 +41,5 @@ jsonpickle typing_extensions; python_version == '3.7' gudhi pre-commit +zenodo-client==0.3.3 +python-dotenv \ No newline at end of file From cf39ac3c435ce8101f463d79497039b1e1303390 Mon Sep 17 00:00:00 2001 From: raphaelreinauer Date: Sun, 14 Apr 2024 22:19:34 +0200 Subject: [PATCH 5/7] Refactor tests --- .../data/datasets/tests/test_dataset_cloud.py | 322 ++++++++++-------- 1 file changed, 172 insertions(+), 150 deletions(-) diff --git a/gdeep/data/datasets/tests/test_dataset_cloud.py b/gdeep/data/datasets/tests/test_dataset_cloud.py index 549303f3..e1a30bbd 100644 --- a/gdeep/data/datasets/tests/test_dataset_cloud.py +++ b/gdeep/data/datasets/tests/test_dataset_cloud.py @@ -1,152 +1,174 @@ -from gdeep.data.datasets import DatasetCloud - -import tempfile -import hashlib -import logging import os -from os import remove, environ -from os.path import join, exists -from shutil import rmtree - -import numpy as np # type: ignore +import tempfile +import shutil +from pathlib import Path +import pytest import torch - -from gdeep.data.datasets.cloud import dataset_cloud -from gdeep.utility.utils import get_checksum - -LOGGER = logging.getLogger(__name__) - - -# Test public access for downloading datasets -def test_public_access(): - # Download a small 
dataset from Google Cloud Storage - dataset = "SmallDataset" - - with tempfile.TemporaryDirectory() as download_directory: - dataset_cloud = DatasetCloud( - dataset, root_download_directory=download_directory, use_public_access=True - ) - dataset_cloud.download() - - # Check if the downloaded files (metadata.json, data.json, labels.json) - # are correct - checksums = { - "data.pt": "2ef68a718e29134cbcbf46c9592f6168", - "labels.pt": "d71992425033c6bf449d175db146a423", - } - - for file in checksums.keys(): - assert ( - get_checksum(join(download_directory, file)) == checksums[file] - ), "File {} is corrupted.".format(file) - - -def test_get_dataset_list(): - # Download directory will not be used as well ass the dataset - # It's only used for initialization of the DatasetCloud object - download_directory = "" - dataset_cloud = DatasetCloud( - "SmallDataset", root_download_directory=download_directory, use_public_access=True - ) - dataset_list = dataset_cloud.get_existing_datasets() - assert len(dataset_list) > 0, "Dataset list is empty." - assert "SmallDataset" in dataset_list, "Dataset list does not contain the dataset." - - # Test if the list does not contain duplicates - assert len(dataset_list) == len( - set(dataset_list) - ), "Dataset list contains duplicates." - - -if "GOOGLE_APPLICATION_CREDENTIALS" in dict(environ): - - def test_update_dataset_list(): - # Create DatasetCloud object - dataset_cloud = DatasetCloud("", use_public_access=False) - # Update the dataset list - dataset_cloud._update_dataset_list() - - def test_upload_and_download(): - for data_format in ["pytorch_tensor", "numpy_array"]: - download_directory = join("examples", "data", "DatasetCloud") - # Generate a dataset - # You don't have to do that if you already have a pickled dataset - size_dataset = 100 - input_dim = 5 - num_labels = 2 - - if data_format == "pytorch_tensor": - data = torch.rand(size_dataset, input_dim) - labels = torch.randint(0, num_labels, (size_dataset,)).long() - - # pickle data and labels - data_filename = "tmp_data.pt" - labels_filename = "tmp_labels.pt" - torch.save(data, data_filename) - torch.save(labels, labels_filename) - elif data_format == "numpy_array": - data = np.random.rand(size_dataset, input_dim) # type: ignore - labels = np.random.randint( - 0, num_labels, (size_dataset,), dtype=np.long - ) - - # pickle data and labels - data_filename = "tmp_data.npy" - labels_filename = "tmp_labels.npy" - np.save(data_filename, data) - np.save(labels_filename, labels) - else: - raise ValueError(f"Unknown data format: {data_format}") - - ## Upload dataset to Cloud - dataset_name = "TmpSmallDataset" - dataset_cloud = DatasetCloud( - dataset_name, - root_download_directory=download_directory, - use_public_access=False, - ) - - # Specify the metadata of the dataset - dataset_cloud._add_metadata( - name=dataset_name, - size_dataset=size_dataset, - input_size=(input_dim,), - num_labels=num_labels, - data_type="tabular", - data_format=data_format, - ) - - # upload dataset to Cloud - dataset_cloud._upload(data_filename, labels_filename) - - # download dataset from Cloud to ´example/data/DataCloud/SmallDataset/´ - dataset_cloud.download() - - # remove created blob - dataset_cloud._data_cloud.delete_blobs(dataset_name) - - # check whether downloaded dataset is the same as the original dataset - if data_format == "pytorch_tensor": - downloaded_files = ["data.pt", "labels.pt", "metadata.json"] - elif data_format == "numpy_array": - downloaded_files = ["data.npy", "labels.npy", "metadata.json"] - else: - raise 
ValueError(f"Unknown data format: {data_format}") - for file in downloaded_files: - hash_original = get_checksum("tmp_" + file) - path_downloaded_file = join(download_directory, dataset_name, file) - hash_downloaded = get_checksum(path_downloaded_file) - assert ( - hash_original == hash_downloaded - ), "Original and downloaded files do not match." - - # remove the labels and data files - remove(data_filename) - remove(labels_filename) - - # remove the downloaded dataset - rmtree(join(download_directory, dataset_name)) - - # remove the metadata file - # will get deleted automatically when dataset_cloud is out of scope. - del dataset_cloud +import yaml +from zenodo_client.struct import Metadata +from gdeep.data.datasets.cloud.dataset_cloud import DatasetCloud +from gdeep.data.datasets.cloud.dataset_uploader import DatasetUploader +from gdeep.data.datasets.cloud.utils import get_config_path, get_dataset_list + +def test_dataset_upload_and_download(): + with tempfile.TemporaryDirectory() as temp_dir: + # Create test data and metadata + data_tensor = torch.randn(10, 5) + labels_tensor = torch.randint(0, 2, (10,)) + metadata = Metadata(title="Test Dataset", description="A test dataset for unit testing") + + # Save data and metadata to temporary files + data_path = Path(temp_dir) / "data.pt" + labels_path = Path(temp_dir) / "labels.pt" + metadata_path = Path(temp_dir) / "metadata.yaml" + torch.save(data_tensor, data_path) + torch.save(labels_tensor, labels_path) + with open(metadata_path, "w") as f: + yaml.dump(dict(metadata), f) + + # Upload the dataset + dataset_name = "test_dataset" + file_paths = [str(data_path), str(labels_path), str(metadata_path)] + uploader = DatasetUploader(sandbox=True) + uploader.upload(dataset_name, metadata, file_paths) + + # Check if the dataset is added to the list of datasets + assert dataset_name in get_dataset_list() + + # Check if the dataset config file is created + config_path = get_config_path(dataset_name) + assert config_path.exists() + + # Download the dataset + download_dir = Path(temp_dir) / "download" + dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) + dataset.download() + + # Check if the downloaded files exist + downloaded_data_path = download_dir / dataset_name / "data.pt" + downloaded_labels_path = download_dir / dataset_name / "labels.pt" + assert downloaded_data_path.exists() + assert downloaded_labels_path.exists() + + # Check if the downloaded data matches the original data + downloaded_data = torch.load(downloaded_data_path) + downloaded_labels = torch.load(downloaded_labels_path) + assert torch.allclose(downloaded_data, data_tensor) + assert torch.allclose(downloaded_labels, labels_tensor) + + # Clean up the uploaded dataset + uploader.remove(dataset_name) + assert not config_path.exists() + assert dataset_name not in get_dataset_list() + +def test_dataset_update(): + with tempfile.TemporaryDirectory() as temp_dir: + # Create test data and metadata + data_tensor = torch.randn(10, 5) + labels_tensor = torch.randint(0, 2, (10,)) + metadata = Metadata(title="Test Dataset", description="A test dataset for unit testing") + + # Save data and metadata to temporary files + data_path = Path(temp_dir) / "data.pt" + labels_path = Path(temp_dir) / "labels.pt" + metadata_path = Path(temp_dir) / "metadata.yaml" + torch.save(data_tensor, data_path) + torch.save(labels_tensor, labels_path) + with open(metadata_path, "w") as f: + yaml.dump(dict(metadata), f) + + # Upload the dataset + dataset_name = "test_dataset" + file_paths 
= [str(data_path), str(labels_path), str(metadata_path)] + uploader = DatasetUploader(sandbox=True) + uploader.upload(dataset_name, metadata, file_paths) + + # Update the dataset with new files + updated_data_tensor = torch.randn(20, 5) + updated_labels_tensor = torch.randint(0, 2, (20,)) + updated_data_path = Path(temp_dir) / "updated_data.pt" + updated_labels_path = Path(temp_dir) / "updated_labels.pt" + torch.save(updated_data_tensor, updated_data_path) + torch.save(updated_labels_tensor, updated_labels_path) + updated_file_paths = [str(updated_data_path), str(updated_labels_path)] + uploader.update(dataset_name, updated_file_paths) + + # Download the updated dataset + download_dir = Path(temp_dir) / "download" + dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) + dataset.download() + + # Check if the downloaded files exist + downloaded_data_path = download_dir / dataset_name / "updated_data.pt" + downloaded_labels_path = download_dir / dataset_name / "updated_labels.pt" + assert downloaded_data_path.exists() + assert downloaded_labels_path.exists() + + # Check if the downloaded data matches the updated data + downloaded_data = torch.load(downloaded_data_path) + downloaded_labels = torch.load(downloaded_labels_path) + assert torch.allclose(downloaded_data, updated_data_tensor) + assert torch.allclose(downloaded_labels, updated_labels_tensor) + + # Clean up the uploaded dataset + uploader.remove(dataset_name) + +def test_dataset_remove(): + with tempfile.TemporaryDirectory() as temp_dir: + # Create test data and metadata + data_tensor = torch.randn(10, 5) + labels_tensor = torch.randint(0, 2, (10,)) + metadata = Metadata(title="Test Dataset", description="A test dataset for unit testing") + + # Save data and metadata to temporary files + data_path = Path(temp_dir) / "data.pt" + labels_path = Path(temp_dir) / "labels.pt" + metadata_path = Path(temp_dir) / "metadata.yaml" + torch.save(data_tensor, data_path) + torch.save(labels_tensor, labels_path) + with open(metadata_path, "w") as f: + yaml.dump(dict(metadata), f) + + # Upload the dataset + dataset_name = "test_dataset" + file_paths = [str(data_path), str(labels_path), str(metadata_path)] + uploader = DatasetUploader(sandbox=True) + uploader.upload(dataset_name, metadata, file_paths) + + # Remove the dataset + uploader.remove(dataset_name) + + # Check if the dataset is removed from the list of datasets + assert dataset_name not in get_dataset_list() + + # Check if the dataset config file is removed + config_path = get_config_path(dataset_name) + assert not config_path.exists() + + +def test_all_datasets_valid(): + with tempfile.TemporaryDirectory() as temp_dir: + download_dir = Path(temp_dir) / "download" + + # Get the list of all datasets in DATASET_CLOUD_CONFIGS_DIR + dataset_list = get_dataset_list() + + # Download each dataset and check for errors + for dataset_name in dataset_list: + dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) + + try: + dataset.download() + + # Check if the downloaded files exist + downloaded_files = list(Path(download_dir / dataset_name).glob("*")) + dataset_files = dataset.config["files"] + assert len(downloaded_files) == len(dataset_files), \ + f"Number of files do not match for dataset: {dataset_name}" + + except Exception as e: + pytest.fail(f"Error occurred while downloading dataset: {dataset_name}\n{str(e)}") + + finally: + # Clean up the downloaded dataset + shutil.rmtree(download_dir / dataset_name) \ No newline at end of file From 
79ed94ce0c112597551bdc2003eb19d532effd93 Mon Sep 17 00:00:00 2001 From: raphaelreinauer Date: Sun, 14 Apr 2024 22:20:07 +0200 Subject: [PATCH 6/7] Refactor dataset imports and fix cloud utils in gdeep.data.datasets --- gdeep/data/datasets/__init__.py | 5 ++--- gdeep/data/datasets/base_dataloaders.py | 2 +- gdeep/data/datasets/{ => cloud}/dataset_uploader.py | 0 gdeep/data/datasets/cloud/utils.py | 5 ++++- 4 files changed, 7 insertions(+), 5 deletions(-) rename gdeep/data/datasets/{ => cloud}/dataset_uploader.py (100%) diff --git a/gdeep/data/datasets/__init__.py b/gdeep/data/datasets/__init__.py index 8d5a859e..2c3ccea1 100644 --- a/gdeep/data/datasets/__init__.py +++ b/gdeep/data/datasets/__init__.py @@ -1,12 +1,11 @@ from .categorical_data import CategoricalDataCloud from .tori import Rotation, ToriDataset -from .dataset_cloud import DatasetCloud -from ._data_cloud import _DataCloud +from .cloud.dataset_cloud import DatasetCloud from .build_datasets import DatasetBuilder, get_dataset from .base_dataloaders import DataLoaderBuilder, AbstractDataLoaderBuilder from .dataset_for_image import ImageClassificationFromFiles from .dataset_form_array import FromArray -from .dataloader_cloud import DlBuilderFromDataCloud +from .cloud.dataloader_cloud import DlBuilderFromDataCloud from .parallel_orbit import ( generate_orbit_parallel, create_pd_orbits, diff --git a/gdeep/data/datasets/base_dataloaders.py b/gdeep/data/datasets/base_dataloaders.py index 1b0554c2..3e28e247 100644 --- a/gdeep/data/datasets/base_dataloaders.py +++ b/gdeep/data/datasets/base_dataloaders.py @@ -26,7 +26,7 @@ from torch.utils.data import Sampler from .build_datasets import get_dataset -from .dataset_cloud import DatasetCloud +from .cloud.dataset_cloud import DatasetCloud from ..transforming_dataset import TransformingDataset diff --git a/gdeep/data/datasets/dataset_uploader.py b/gdeep/data/datasets/cloud/dataset_uploader.py similarity index 100% rename from gdeep/data/datasets/dataset_uploader.py rename to gdeep/data/datasets/cloud/dataset_uploader.py diff --git a/gdeep/data/datasets/cloud/utils.py b/gdeep/data/datasets/cloud/utils.py index c26ee481..42ef284e 100644 --- a/gdeep/data/datasets/cloud/utils.py +++ b/gdeep/data/datasets/cloud/utils.py @@ -5,4 +5,7 @@ def get_config_path(dataset_name: str) -> Path: return Path(DATASET_CLOUD_CONFIGS_DIR) / f"{dataset_name}.yaml" def get_download_directory(dataset_name: str, root_download_directory: str) -> Path: - return Path(root_download_directory) / dataset_name \ No newline at end of file + return Path(root_download_directory) / dataset_name + +def get_dataset_list() -> list[str]: + return [path.stem for path in Path(DATASET_CLOUD_CONFIGS_DIR).iterdir() if path.suffix == ".yaml"] \ No newline at end of file From f355184c5c96f4f890a8e68cdac10da81e2e8293 Mon Sep 17 00:00:00 2001 From: raphaelreinauer Date: Sun, 14 Apr 2024 22:36:02 +0200 Subject: [PATCH 7/7] Make tests more robust --- gdeep/data/datasets/cloud/dataset_uploader.py | 2 + .../data/datasets/tests/test_dataset_cloud.py | 114 +++++++++--------- 2 files changed, 61 insertions(+), 55 deletions(-) diff --git a/gdeep/data/datasets/cloud/dataset_uploader.py b/gdeep/data/datasets/cloud/dataset_uploader.py index 1e4b4942..1db332ff 100644 --- a/gdeep/data/datasets/cloud/dataset_uploader.py +++ b/gdeep/data/datasets/cloud/dataset_uploader.py @@ -14,6 +14,8 @@ def __init__(self, access_token: Optional[str] = None, sandbox: bool = False): def upload(self, dataset_name: str, metadata: Metadata, file_paths: list[str]) -> None: 
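Aside on the new registry introduced in PATCH 6 above: the cloud/utils.py helpers treat DATASET_CLOUD_CONFIGS_DIR as a flat YAML registry, one <dataset_name>.yaml file per dataset holding the Zenodo deposition id and file list that DatasetUploader.upload writes. A minimal self-contained sketch of that contract follows; the dataset name, deposition id, and file names are invented, and a throwaway directory stands in for the real configs directory.

# Illustrative sketch only: mirrors the helpers in gdeep/data/datasets/cloud/utils.py
# against a temporary configs directory instead of DATASET_CLOUD_CONFIGS_DIR.
import tempfile
from pathlib import Path

import yaml  # PyYAML, already used by DatasetUploader

with tempfile.TemporaryDirectory() as tmp:
    configs_dir = Path(tmp)  # stand-in for DATASET_CLOUD_CONFIGS_DIR

    def get_config_path(dataset_name: str) -> Path:
        # one YAML config per dataset, keyed by name
        return configs_dir / f"{dataset_name}.yaml"

    def get_dataset_list() -> list[str]:
        # the registry is just the stems of the YAML files present
        return [p.stem for p in configs_dir.iterdir() if p.suffix == ".yaml"]

    # Shape of the config that DatasetUploader.upload persists after a deposition
    # (deposition_id and file names here are made up):
    config = {"deposition_id": 123456, "files": ["data.pt", "labels.pt", "metadata.yaml"]}
    with open(get_config_path("my_dataset"), "w") as f:
        yaml.dump(config, f)

    assert "my_dataset" in get_dataset_list()
    # DatasetCloud would then place the downloaded files under
    # get_download_directory("my_dataset", root) == Path(root) / "my_dataset".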
config_path = get_config_path(dataset_name) + if os.path.exists(config_path): + raise FileExistsError(f"Configuration file already exists for dataset: {dataset_name}, use update method instead") deposition = self.zenodo_client.create(data=metadata, paths=file_paths) deposition_id = deposition["id"] file_names = [file["filename"] for file in deposition["files"]] diff --git a/gdeep/data/datasets/tests/test_dataset_cloud.py b/gdeep/data/datasets/tests/test_dataset_cloud.py index e1a30bbd..11555efc 100644 --- a/gdeep/data/datasets/tests/test_dataset_cloud.py +++ b/gdeep/data/datasets/tests/test_dataset_cloud.py @@ -32,34 +32,36 @@ def test_dataset_upload_and_download(): uploader = DatasetUploader(sandbox=True) uploader.upload(dataset_name, metadata, file_paths) - # Check if the dataset is added to the list of datasets - assert dataset_name in get_dataset_list() + try: + # Check if the dataset is added to the list of datasets + assert dataset_name in get_dataset_list() - # Check if the dataset config file is created - config_path = get_config_path(dataset_name) - assert config_path.exists() + # Check if the dataset config file is created + config_path = get_config_path(dataset_name) + assert config_path.exists() - # Download the dataset - download_dir = Path(temp_dir) / "download" - dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) - dataset.download() - - # Check if the downloaded files exist - downloaded_data_path = download_dir / dataset_name / "data.pt" - downloaded_labels_path = download_dir / dataset_name / "labels.pt" - assert downloaded_data_path.exists() - assert downloaded_labels_path.exists() - - # Check if the downloaded data matches the original data - downloaded_data = torch.load(downloaded_data_path) - downloaded_labels = torch.load(downloaded_labels_path) - assert torch.allclose(downloaded_data, data_tensor) - assert torch.allclose(downloaded_labels, labels_tensor) - - # Clean up the uploaded dataset - uploader.remove(dataset_name) - assert not config_path.exists() - assert dataset_name not in get_dataset_list() + # Download the dataset + download_dir = Path(temp_dir) / "download" + dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) + dataset.download() + + # Check if the downloaded files exist + downloaded_data_path = download_dir / dataset_name / "data.pt" + downloaded_labels_path = download_dir / dataset_name / "labels.pt" + assert downloaded_data_path.exists() + assert downloaded_labels_path.exists() + + # Check if the downloaded data matches the original data + downloaded_data = torch.load(downloaded_data_path) + downloaded_labels = torch.load(downloaded_labels_path) + assert torch.allclose(downloaded_data, data_tensor) + assert torch.allclose(downloaded_labels, labels_tensor) + + finally: + # Clean up the uploaded dataset + uploader.remove(dataset_name) + assert not config_path.exists() + assert dataset_name not in get_dataset_list() def test_dataset_update(): with tempfile.TemporaryDirectory() as temp_dir: @@ -83,35 +85,37 @@ def test_dataset_update(): uploader = DatasetUploader(sandbox=True) uploader.upload(dataset_name, metadata, file_paths) - # Update the dataset with new files - updated_data_tensor = torch.randn(20, 5) - updated_labels_tensor = torch.randint(0, 2, (20,)) - updated_data_path = Path(temp_dir) / "updated_data.pt" - updated_labels_path = Path(temp_dir) / "updated_labels.pt" - torch.save(updated_data_tensor, updated_data_path) - torch.save(updated_labels_tensor, updated_labels_path) - 
updated_file_paths = [str(updated_data_path), str(updated_labels_path)] - uploader.update(dataset_name, updated_file_paths) - - # Download the updated dataset - download_dir = Path(temp_dir) / "download" - dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) - dataset.download() - - # Check if the downloaded files exist - downloaded_data_path = download_dir / dataset_name / "updated_data.pt" - downloaded_labels_path = download_dir / dataset_name / "updated_labels.pt" - assert downloaded_data_path.exists() - assert downloaded_labels_path.exists() - - # Check if the downloaded data matches the updated data - downloaded_data = torch.load(downloaded_data_path) - downloaded_labels = torch.load(downloaded_labels_path) - assert torch.allclose(downloaded_data, updated_data_tensor) - assert torch.allclose(downloaded_labels, updated_labels_tensor) - - # Clean up the uploaded dataset - uploader.remove(dataset_name) + try: + # Update the dataset with new files + updated_data_tensor = torch.randn(20, 5) + updated_labels_tensor = torch.randint(0, 2, (20,)) + updated_data_path = Path(temp_dir) / "updated_data.pt" + updated_labels_path = Path(temp_dir) / "updated_labels.pt" + torch.save(updated_data_tensor, updated_data_path) + torch.save(updated_labels_tensor, updated_labels_path) + updated_file_paths = [str(updated_data_path), str(updated_labels_path)] + uploader.update(dataset_name, updated_file_paths) + + # Download the updated dataset + download_dir = Path(temp_dir) / "download" + dataset = DatasetCloud(dataset_name, root_download_directory=str(download_dir)) + dataset.download() + + # Check if the downloaded files exist + downloaded_data_path = download_dir / dataset_name / "updated_data.pt" + downloaded_labels_path = download_dir / dataset_name / "updated_labels.pt" + assert downloaded_data_path.exists() + assert downloaded_labels_path.exists() + + # Check if the downloaded data matches the updated data + downloaded_data = torch.load(downloaded_data_path) + downloaded_labels = torch.load(downloaded_labels_path) + assert torch.allclose(downloaded_data, updated_data_tensor) + assert torch.allclose(downloaded_labels, updated_labels_tensor) + + finally: + # Clean up the uploaded dataset + uploader.remove(dataset_name) def test_dataset_remove(): with tempfile.TemporaryDirectory() as temp_dir:
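Taken together, the tests above describe the intended end-to-end flow: DatasetUploader pushes files to a Zenodo (sandbox) deposition and records it in a local YAML config, DatasetCloud later downloads the files by dataset name, and remove discards both the deposition and the config. A condensed sketch of that flow, not part of the test suite: it assumes a valid sandbox token in ZENODO_API_TOKEN, network access, and that the invented name "demo_dataset" is not already registered.

# Sketch of the upload -> download -> remove round trip exercised by the tests above.
import tempfile
from pathlib import Path

import torch
from zenodo_client.struct import Metadata

from gdeep.data.datasets.cloud.dataset_cloud import DatasetCloud
from gdeep.data.datasets.cloud.dataset_uploader import DatasetUploader

with tempfile.TemporaryDirectory() as tmp:
    # Toy tensors standing in for a real dataset.
    data_path = Path(tmp) / "data.pt"
    labels_path = Path(tmp) / "labels.pt"
    torch.save(torch.randn(10, 5), data_path)
    torch.save(torch.randint(0, 2, (10,)), labels_path)

    uploader = DatasetUploader(sandbox=True)  # token read from ZENODO_API_TOKEN
    uploader.upload(
        "demo_dataset",
        Metadata(title="Demo", description="Toy upload for illustration"),
        [str(data_path), str(labels_path)],
    )
    try:
        # Files are fetched by dataset name and land in <root>/demo_dataset/.
        DatasetCloud("demo_dataset", root_download_directory=tmp).download()
    finally:
        # Discards the Zenodo deposition and deletes the local YAML config.
        uploader.remove("demo_dataset")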