Skip to content

Commit

Permalink
Provide output_folder argument
Browse files Browse the repository at this point in the history
  • Loading branch information
larsll committed Nov 23, 2024
1 parent f72f085 commit 95ad16d
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions utils/download-car-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,12 @@ def check_model_file(prefix):
print(f"Error checking {prefix}: {e}")
return None

def get_matching_prefixes(prefix_pattern):
def list_matching_prefixes(bucket_name, prefix_pattern):
"""
Get a list of prefixes in the S3 bucket that match the given pattern.
List all prefixes in the specified S3 bucket that match the given pattern.
Args:
bucket_name (str): The name of the S3 bucket.
prefix_pattern (str): The pattern to match prefixes against.
Returns:
Expand All @@ -74,19 +75,22 @@ def get_matching_prefixes(prefix_pattern):
print(f"Error listing prefixes: {e}")
return []

def download_and_rename_model_file(prefix, file_key):
def download_and_rename_model_file(prefix, file_key, output_folder="."):
"""
Download and rename the model.tar.gz file from the specified file key.
Args:
prefix (str): The prefix of the model file.
file_key (str): The S3 key of the model file to download.
output_folder (str): The folder where the downloaded file should be placed. Defaults to the current directory.
Returns:
bool: True if the model file is downloaded and renamed, False otherwise.
"""
try:
local_filename = os.path.join("tmp", f"{prefix.rstrip('/')}.tar.gz")
if not os.path.exists(output_folder):
os.makedirs(output_folder)
local_filename = os.path.join(output_folder, f"{prefix.rstrip('/')}.tar.gz")
s3.download_file(bucket_name, file_key, local_filename)
print(f"Downloaded and renamed {file_key} to {local_filename}")
return True
Expand Down Expand Up @@ -114,12 +118,13 @@ def validate_s3_connection():
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Check and download model files from S3.')
parser.add_argument('--pattern', type=str, required=True, help='Pattern for prefixes to check')
parser.add_argument('--output_folder', type=str, default='.', help='Folder to store downloaded files')
args = parser.parse_args()

validate_s3_connection()

matching_prefixes = get_matching_prefixes(args.pattern)
matching_prefixes = list_matching_prefixes(bucket_name, args.pattern)
for prefix in matching_prefixes:
model_file_path = check_model_file(prefix)
if model_file_path:
download_and_rename_model_file(prefix, model_file_path)
download_and_rename_model_file(prefix, model_file_path, args.output_folder)

0 comments on commit 95ad16d

Please sign in to comment.