Skip to content

Commit

Permalink
sdk/python: add support for ext mapping (multiobj transform)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Koo <[email protected]>
  • Loading branch information
rkoo19 committed Dec 20, 2024
1 parent 3fb75ad commit 8c60006
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 10 deletions.
9 changes: 6 additions & 3 deletions python/aistore/sdk/multiobj/object_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
import logging
from typing import List, Iterable
from typing import Dict, List, Iterable

from aistore.sdk.ais_source import AISSource
from aistore.sdk.const import (
Expand Down Expand Up @@ -282,13 +282,14 @@ def copy(
value=value,
).text

# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments, too-many-locals
def transform(
self,
to_bck: "Bucket",
etl_name: str,
timeout: str = DEFAULT_ETL_TIMEOUT,
prepend: str = "",
ext: Dict[str, str] = None,
continue_on_error: bool = False,
dry_run: bool = False,
force: bool = False,
Expand All @@ -304,6 +305,8 @@ def transform(
etl_name (str): Name of existing ETL to apply
timeout (str): Timeout of the ETL job (e.g. 5m for 5 minutes)
prepend (str, optional): Value to prepend to the name of resulting transformed objects
ext (Dict[str, str], optional): dict mapping each extension to be replaced to its new extension
(e.g. {"jpg": "txt"} renames objects ending in "jpg" so that they end in "txt")
continue_on_error (bool, optional): Whether to continue if there is an error transforming a single object
dry_run (bool, optional): Skip performing the transform and just log the intended actions
force (bool, optional): Force this job to run over others in case it conflicts
Expand Down Expand Up @@ -340,7 +343,7 @@ def transform(
transform_msg = TransformBckMsg(etl_name=etl_name, timeout=timeout)
value = TCMultiObj(
to_bck=to_bck.as_model(),
tc_msg=TCBckMsg(transform_msg=transform_msg, copy_msg=copy_msg),
tc_msg=TCBckMsg(ext=ext, transform_msg=transform_msg, copy_msg=copy_msg),
object_selection=self._obj_collection.get_value(),
continue_on_err=continue_on_error,
num_workers=num_workers,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@

# Names
PREFIX_NAME = "prefix-"
SUFFIX_NAME = "-suffix"
SUFFIX_NAME = "-suffix.ext"
4 changes: 2 additions & 2 deletions python/tests/integration/sdk/test_bucket_ops_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tests.integration import AWS_BUCKET
from tests.integration.sdk.remote_enabled_test import RemoteEnabledTest
from tests.const import STRESS_TEST_OBJECT_COUNT, TEST_TIMEOUT
from tests.const import STRESS_TEST_OBJECT_COUNT, TEST_TIMEOUT, SUFFIX_NAME


class TestBucketOpsStress(RemoteEnabledTest):
Expand All @@ -21,7 +21,7 @@ def setUp(self) -> None:
)
def test_stress_copy_objects_sync_flag(self):
num_obj = STRESS_TEST_OBJECT_COUNT
obj_names = self._create_objects(num_obj=num_obj, suffix="-suffix")
obj_names = self._create_objects(num_obj=num_obj, suffix=SUFFIX_NAME)
to_bck = self._create_bucket()

obj_group = self.bucket.objects(obj_names=obj_names)
Expand Down
10 changes: 9 additions & 1 deletion python/tests/integration/sdk/test_object_group_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestObjectGroupOps(RemoteEnabledTest):
def setUp(self) -> None:
super().setUp()
self.suffix = SUFFIX_NAME
self.obj_extension = self.suffix[-3]
# Use a slightly larger file size to allow for blob threshold (must be > 1MiB)
self.file_size = MEDIUM_FILE_SIZE
self.obj_names = self._create_objects(
Expand Down Expand Up @@ -347,6 +348,7 @@ def transform(input_bytes):

to_bck = self._create_bucket()
new_prefix = PREFIX_NAME
new_ext = "new-ext"
self.assertEqual(0, len(to_bck.list_all_objects(prefix=self.obj_prefix)))
self.assertEqual(
OBJECT_COUNT, len(self.bucket.list_all_objects(prefix=self.obj_prefix))
Expand All @@ -356,6 +358,7 @@ def transform(input_bytes):
"to_bck": to_bck,
"etl_name": md5_etl.name,
"prepend": new_prefix,
"ext": {self.obj_extension: new_ext},
}
if num_workers is not None:
transform_kwargs["num_workers"] = num_workers
Expand All @@ -371,7 +374,9 @@ def transform(input_bytes):
for name in self.obj_names
]
to_obj_values = [
to_bck.object(new_prefix + name).get_reader().read_all()
to_bck.object(self._new_name(name, new_prefix, new_ext))
.get_reader()
.read_all()
for name in self.obj_names
]
self.assertEqual(to_obj_values, from_obj_hashes)
Expand All @@ -390,3 +395,6 @@ def _evict_objects(self, obj_group):
self._check_all_objects_cached(
len(obj_group.list_names()), expected_cached=False
)

def _new_name(self, s, prefix, ext):
    """
    Build the object name expected in the destination bucket after a transform.

    Args:
        s (str): Original object name
        prefix (str): Value prepended to the resulting object name
        ext (str): Extension that replaces the original one

    Returns:
        str: prefix, then the name with its trailing self.obj_extension removed, then ext
    """
    old_ext = self.obj_extension
    # str.rstrip() strips any trailing run of the given *characters*, not a
    # suffix, so it can eat into names that do not end with the extension.
    # Only cut when the extension is a real, non-empty suffix of the name.
    if old_ext and s.endswith(old_ext):
        s = s[: -len(old_ext)]
    return prefix + s + ext
2 changes: 2 additions & 0 deletions python/tests/unit/sdk/multiobj/test_object_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def test_transform(self):
prepend_val = "new_prefix-"
self.expected_value["coer"] = True
self.expected_value["prepend"] = prepend_val
self.expected_value["ext"] = {"wav": "flac"}
self.expected_value["request_timeout"] = timeout
self.expected_value["dry_run"] = True
self.expected_value["force"] = True
Expand All @@ -208,6 +209,7 @@ def test_transform(self):
self.expected_value,
to_bck=self.dest_bucket,
prepend=prepend_val,
ext={"wav": "flac"},
etl_name=ETL_NAME,
timeout=timeout,
dry_run=True,
Expand Down
12 changes: 9 additions & 3 deletions python/tests/unit/sdk/multiobj/test_object_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_object_range_defaults(self):
object_range = ObjectRange(
prefix=self.prefix, min_index=self.min_index, max_index=self.max_index
)
self.assertEqual("prefix-{4..9..1}", str(object_range))
self.assertEqual(f"{self.prefix}{{4..9..1}}", str(object_range))

def test_object_range(self):
object_range = ObjectRange(
Expand All @@ -31,7 +31,9 @@ def test_object_range(self):
step=self.step,
suffix=self.suffix,
)
self.assertEqual("prefix-{004..009..2}-suffix", str(object_range))
self.assertEqual(
f"{self.prefix}{{004..009..2}}{self.suffix}", str(object_range)
)

def test_object_range_prefix_only(self):
object_range = ObjectRange(prefix=self.prefix)
Expand Down Expand Up @@ -77,5 +79,9 @@ def test_iter(self):
step=self.step,
suffix=self.suffix,
)
expected_range = ["prefix-004-suffix", "prefix-006-suffix", "prefix-008-suffix"]
expected_range = [
f"{self.prefix}004{self.suffix}",
f"{self.prefix}006{self.suffix}",
f"{self.prefix}008{self.suffix}",
]
self.assertEqual(expected_range, list(object_range))

0 comments on commit 8c60006

Please sign in to comment.