diff --git a/utils/download-car-model.py b/utils/download-car-model.py index e6f4ba26..2aeb3ba4 100755 --- a/utils/download-car-model.py +++ b/utils/download-car-model.py @@ -55,11 +55,12 @@ def check_model_file(prefix): print(f"Error checking {prefix}: {e}") return None -def get_matching_prefixes(prefix_pattern): +def list_matching_prefixes(bucket_name, prefix_pattern): """ - Get a list of prefixes in the S3 bucket that match the given pattern. + List all prefixes in the specified S3 bucket that match the given pattern. Args: + bucket_name (str): The name of the S3 bucket. prefix_pattern (str): The pattern to match prefixes against. Returns: @@ -74,19 +75,22 @@ def get_matching_prefixes(prefix_pattern): print(f"Error listing prefixes: {e}") return [] -def download_and_rename_model_file(prefix, file_key): +def download_and_rename_model_file(prefix, file_key, output_folder="."): """ Download and rename the model.tar.gz file from the specified file key. Args: prefix (str): The prefix of the model file. file_key (str): The S3 key of the model file to download. + output_folder (str): The folder where the downloaded file should be placed. Defaults to the current directory. Returns: bool: True if the model file is downloaded and renamed, False otherwise. """ try: - local_filename = os.path.join("tmp", f"{prefix.rstrip('/')}.tar.gz") + if not os.path.exists(output_folder): + os.makedirs(output_folder) + local_filename = os.path.join(output_folder, f"{prefix.rstrip('/')}.tar.gz") s3.download_file(bucket_name, file_key, local_filename) print(f"Downloaded and renamed {file_key} to {local_filename}") return True @@ -114,12 +118,13 @@ def validate_s3_connection(): if __name__ == "__main__": parser = argparse.ArgumentParser(description='Check and download model files from S3.') parser.add_argument('--pattern', type=str, required=True, help='Pattern for prefixes to check') + parser.add_argument('--output_folder', type=str, default='.', help='Folder to store downloaded files') args = parser.parse_args() validate_s3_connection() - matching_prefixes = get_matching_prefixes(args.pattern) + matching_prefixes = list_matching_prefixes(bucket_name, args.pattern) for prefix in matching_prefixes: model_file_path = check_model_file(prefix) if model_file_path: - download_and_rename_model_file(prefix, model_file_path) \ No newline at end of file + download_and_rename_model_file(prefix, model_file_path, args.output_folder) \ No newline at end of file