
Commit

async write to cache
KeplerC committed Aug 31, 2024
1 parent a0e813a commit 143a4fe
Showing 3 changed files with 83 additions and 170 deletions.
26 changes: 17 additions & 9 deletions benchmarks/openx.py
@@ -127,14 +127,18 @@ def __init__(self, exp_dir, dataset_name, num_trajectories, log_frequency=DEFAUL
super().__init__(exp_dir, dataset_name, num_trajectories, dataset_type="vla", log_frequency=log_frequency)
self.file_extension = ".vla"

def measure_loading_time(self, mode="no_cache"):
def measure_loading_time(self, save_to_cache=True):
start_time = time.time()
loader = VLALoader(self.dataset_dir, cache_dir=CACHE_DIR)
if save_to_cache:
mode = "cache"
else:
mode = "no_cache"
for i, data in enumerate(loader, 1):
if self.num_trajectories != -1 and i > self.num_trajectories:
break
try:
self._recursively_load_data(data.load(mode=mode))
self._recursively_load_data(data.load(save_to_cache=save_to_cache))
elapsed_time = time.time() - start_time
self.write_result(f"VLA-{mode.capitalize()}", elapsed_time, i)
if i % self.log_frequency == 0:
@@ -143,7 +147,7 @@ def measure_loading_time(self, mode="no_cache"):
print(f"Failed to load data: {e}")
return time.time() - start_time

def measure_random_loading_time(self, num_loads):
def measure_random_loading_time(self, num_loads, save_to_cache=True):
start_time = time.time()
loader = VLALoader(self.dataset_dir, cache_dir=CACHE_DIR)
dataset_size = len(loader)
@@ -153,7 +157,7 @@ def measure_random_loading_time(self, num_loads):
random_index = np.random.randint(0, dataset_size)
data = loader[random_index]
try:
self._recursively_load_data(data.load(mode="cache"))
self._recursively_load_data(data.load(save_to_cache=save_to_cache))
elapsed_time = time.time() - start_time
self.write_result(f"VLA-RandomLoad", elapsed_time, i + 1)
if (i + 1) % self.log_frequency == 0:
@@ -168,14 +172,18 @@ def __init__(self, exp_dir, dataset_name, num_trajectories, log_frequency=DEFAUL
super().__init__(exp_dir, dataset_name, num_trajectories, dataset_type="ffv1", log_frequency=log_frequency)
self.file_extension = ".vla"

def measure_loading_time(self, mode="no_cache"):
def measure_loading_time(self, save_to_cache=True):
start_time = time.time()
loader = VLALoader(self.dataset_dir, cache_dir=CACHE_DIR)
if save_to_cache:
mode = "cache"
else:
mode = "no_cache"
for i, data in enumerate(loader, 1):
if self.num_trajectories != -1 and i > self.num_trajectories:
break
try:
self._recursively_load_data(data.load(mode=mode))
self._recursively_load_data(data.load(save_to_cache=save_to_cache))
elapsed_time = time.time() - start_time
self.write_result(f"FFV1-{mode.capitalize()}", elapsed_time, i)
if i % self.log_frequency == 0:
@@ -184,7 +192,7 @@ def measure_loading_time(self, mode="no_cache"):
print(f"Failed to load data: {e}")
return time.time() - start_time

def measure_random_loading_time(self, num_loads):
def measure_random_loading_time(self, num_loads, save_to_cache=True):
start_time = time.time()
loader = VLALoader(self.dataset_dir, cache_dir=CACHE_DIR)
dataset_size = len(loader)
@@ -194,7 +202,7 @@ def measure_random_loading_time(self, num_loads):
random_index = np.random.randint(0, dataset_size)
data = loader[random_index]
try:
self._recursively_load_data(data.load(mode="cache"))
self._recursively_load_data(data.load(save_to_cache=save_to_cache))
elapsed_time = time.time() - start_time
self.write_result(f"FFV1-RandomLoad", elapsed_time, i + 1)
if (i + 1) % self.log_frequency == 0:
@@ -260,7 +268,7 @@ def evaluation(args):
print(f"Evaluating dataset: {dataset_name}")

handlers = [
RLDSHandler(args.exp_dir, dataset_name, args.num_trajectories, args.log_frequency),
# RLDSHandler(args.exp_dir, dataset_name, args.num_trajectories, args.log_frequency),
VLAHandler(args.exp_dir, dataset_name, args.num_trajectories, args.log_frequency),
HDF5Handler(args.exp_dir, dataset_name, args.num_trajectories, args.log_frequency),
FFV1Handler(args.exp_dir, dataset_name, args.num_trajectories, args.log_frequency)
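
A minimal sketch of how the updated benchmark entry points might be driven after this change. Only the handler name and the new save_to_cache / num_loads parameters come from the diff above; the import path, directories, and dataset name are placeholders:

from benchmarks.openx import VLAHandler  # import path assumed from the file layout

handler = VLAHandler("./results", "example_dataset", num_trajectories=10, log_frequency=10)
cold = handler.measure_loading_time(save_to_cache=True)             # first pass decodes and fills the cache
warm = handler.measure_random_loading_time(50, save_to_cache=True)  # random reads hit the cache once it exists
print(f"sequential: {cold:.2f}s, random: {warm:.2f}s")
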
34 changes: 15 additions & 19 deletions fog_x/loader/vla.py
@@ -2,41 +2,37 @@
import fog_x
import glob
import logging

logger = logging.getLogger(__name__)
import asyncio
import os
from typing import Text

logger = logging.getLogger(__name__)

class VLALoader(BaseLoader):
def __init__(self, path: Text, cache_dir=None):
"""initialize VLALoader from paths
Args:
path (_type_): path to the vla files
can be a directory, or a glob pattern
cache_dir (str, optional): directory used for cached HDF5 files. Defaults to None.
"""
super(VLALoader, self).__init__(path)
self.index = 0
self.files = self._get_files(path)
self.cache_dir = cache_dir
self.loop = asyncio.get_event_loop()

def _get_files(self, path):
if "*" in path:
self.files = glob.glob(path)
return glob.glob(path)
elif os.path.isdir(path):
self.files = glob.glob(os.path.join(path, "*.vla"))
return glob.glob(os.path.join(path, "*.vla"))
else:
self.files = [path]

self.cache_dir = cache_dir
return [path]

def _read_vla(self, data_path):
async def _read_vla_async(self, data_path):
logger.debug(f"Reading {data_path}")
if self.cache_dir:
traj = fog_x.Trajectory(data_path, cache_dir=self.cache_dir)
else:
traj = fog_x.Trajectory(data_path)
traj = fog_x.Trajectory(data_path, cache_dir=self.cache_dir)
await traj.load_async()
return traj

def _read_vla(self, data_path):
return self.loop.run_until_complete(self._read_vla_async(data_path))

def __iter__(self):
return self

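
A short usage sketch of the loader after this change, assuming iteration yields fog_x.Trajectory objects (as the benchmark code above suggests); the glob pattern, cache directory, and import path are placeholders:

from fog_x.loader.vla import VLALoader  # import path assumed

loader = VLALoader("/data/robot/*.vla", cache_dir="/tmp/fog_x_cache")
for traj in loader:
    arrays = traj.load(save_to_cache=True)  # dict of numpy arrays per the new Trajectory.load below
    print({name: value.shape for name, value in arrays.items()})

Note that _read_vla still blocks: it drives the coroutine with loop.run_until_complete, so callers keep a synchronous interface while the decode-and-cache work is expressed as async code.
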
193 changes: 51 additions & 142 deletions fog_x/trajectory.py
@@ -8,6 +8,8 @@
from fog_x import FeatureType
import pickle
import h5py
import asyncio
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)

@@ -77,6 +79,10 @@ def __init__(
self.stream_id_to_info = {} # stream_id: StreamInfo
self.is_closed = False
self.lossy_compression = lossy_compression
self.pending_write_tasks = [] # List to keep track of pending write tasks
self.cache_write_lock = asyncio.Lock()
self.cache_write_task = None
self.executor = ThreadPoolExecutor(max_workers=1)

# check if the path exists
# if not, create a new file and start data collection
@@ -145,33 +151,40 @@ def close(self, compact=True):
self.container_file = None
self.is_closed = True

def load(self, mode = "cache"):
def load(self, save_to_cache=True, return_h5=False):
"""
load the container file
Load the trajectory data.
returns the container file
Args:
mode (str): "cache" to use cached data if available, "no_cache" to always load from container.
return_h5 (bool): If True, return h5py.File object instead of numpy arrays.
workflow:
- check if a cached mmap/hdf5 file exists
- if exists, load the file
- otherwise: load the container file with the entire vla trajectory
Returns:
dict: A dictionary of numpy arrays if return_h5 is False, otherwise an h5py.File object.
"""
if mode == "cache":
if os.path.exists(self.cache_file_name):
logger.debug(f"Loading the cached file {self.cache_file_name}")
self.trajectory_data = self._load_from_cache()

return asyncio.get_event_loop().run_until_complete(
self.load_async(save_to_cache=save_to_cache, return_h5=return_h5)
)

async def load_async(self, save_to_cache=True, return_h5=False):
if os.path.exists(self.cache_file_name):
logger.debug(f"Loading the cached file {self.cache_file_name}")
if return_h5:
return h5py.File(self.cache_file_name, "r")
else:
logger.debug(f"Loading the container file {self.path}, saving to cache {self.cache_file_name}")
self.trajectory_data = self._load_from_container(save_to_cache=True)
elif mode == "no_cache":
logger.debug(f"Loading the container file {self.path} without cache")
# self.trajectory_data = self._load_from_container_to_h5()
self.trajectory_data = self._load_from_container(save_to_cache=False)
with h5py.File(self.cache_file_name, "r") as h5_cache:
return {k: np.array(v) for k, v in h5_cache.items()}
else:
logger.debug(f"No option provided. Force loading from container file {self.path}")
self.trajectory_data = self._load_from_container(save_to_cache=False)

return self.trajectory_data
logger.debug(f"Loading the container file {self.path}, saving to cache {self.cache_file_name}")
np_cache = self._load_from_container()
if save_to_cache:
await self._async_write_to_cache(np_cache)

if return_h5:
return h5py.File(self.cache_file_name, "r")
else:
return np_cache

def init_feature_streams(self, feature_spec: Dict):
"""
@@ -346,115 +359,15 @@ def _load_from_cache(self):
h5_cache = h5py.File(self.cache_file_name, "r")
return h5_cache

def _load_from_container_to_h5(self):
"""
load the container file with the entire vla trajectory
workflow:
- get schema of the container file
- preallocate decoded streams
- decode frame by frame and store in the preallocated memory
"""

container = av.open(self.path, mode="r", format="matroska")
h5_cache = h5py.File(self.cache_file_name, "w")
streams = container.streams

# preallocate memory for the streams in h5
for stream in streams:
feature_name = stream.metadata.get("FEATURE_NAME")
if feature_name is None:
logger.warn(f"Skipping stream without FEATURE_NAME: {stream}")
continue
feature_type = FeatureType.from_str(stream.metadata.get("FEATURE_TYPE"))
self.feature_name_to_stream[feature_name] = stream
self.feature_name_to_feature_type[feature_name] = feature_type
# Preallocate arrays with the shape [None, X, Y, Z]
# where X, Y, Z are the dimensions of the feature

logger.debug(
f"creating a cache for {feature_name} with shape {feature_type.shape}"
)

if feature_type.dtype == "string":
# strings are not supported in h5py, so we store them as objects
h5_cache.create_dataset(
feature_name,
(0,) + feature_type.shape,
maxshape=(None,) + feature_type.shape,
dtype=h5py.special_dtype(vlen=str),
chunks=(100,) + feature_type.shape,
)
else:
h5_cache.create_dataset(
feature_name,
(0,) + feature_type.shape,
maxshape=(None,) + feature_type.shape,
dtype=feature_type.dtype,
chunks=(100,) + feature_type.shape,
)

# decode the frames and store in the preallocated memory
d_feature_length = {feature: 0 for feature in self.feature_name_to_stream}
for packet in container.demux(list(streams)):
feature_name = packet.stream.metadata.get("FEATURE_NAME")
if feature_name is None:
logger.debug(f"Skipping stream without FEATURE_NAME: {stream}")
continue
feature_type = FeatureType.from_str(
packet.stream.metadata.get("FEATURE_TYPE")
)
logger.debug(
f"Decoding {feature_name} with shape {feature_type.shape} and dtype {feature_type.dtype} with time {packet.dts}"
)
feature_codec = packet.stream.codec_context.codec.name
if feature_codec == "h264":
frames = packet.decode()

for frame in frames:
if feature_type.dtype == "float32":
data = frame.to_ndarray(format="gray").reshape(
feature_type.shape
)
else:
data = frame.to_ndarray(format="rgb24").reshape(
feature_type.shape
)
h5_cache[feature_name].resize(
h5_cache[feature_name].shape[0] + 1, axis=0
)
h5_cache[feature_name][-1] = data
d_feature_length[feature_name] += 1
else:
packet_in_bytes = bytes(packet)
if packet_in_bytes:
# decode the packet
data = pickle.loads(packet_in_bytes)
h5_cache[feature_name].resize(
h5_cache[feature_name].shape[0] + 1, axis=0
)
h5_cache[feature_name][-1] = data
d_feature_length[feature_name] += 1
else:
logger.debug(f"Skipping empty packet: {packet} for {feature_name}")
container.close()
h5_cache.close()
h5_cache = h5py.File(self.cache_file_name, "r")
return h5_cache

def _load_from_container(self, save_to_cache: bool = True):
def _load_from_container(self):
"""
Load the container file with the entire VLA trajectory.
args:
save_to_cache: save the decoded data to the cache file
returns:
h5_cache: h5py file with the decoded data
or
dict: dictionary with the decoded data
np_cache: dictionary with the decoded data
Workflow:
- Get schema of the container file.
@@ -544,37 +457,33 @@ def _get_length_of_stream(container, stream):
logger.debug(f"Length of the stream {feature_name} is {d_feature_length[feature_name]}")
container.close()

if save_to_cache:
# create and save it to be hdf5 file
h5_cache = h5py.File(self.cache_file_name, "w")
return np_cache

async def _async_write_to_cache(self, np_cache):
async with self.cache_write_lock:
await asyncio.get_event_loop().run_in_executor(
self.executor,
self._write_to_cache,
np_cache
)

def _write_to_cache(self, np_cache):
with h5py.File(self.cache_file_name, "w") as h5_cache:
for feature_name, data in np_cache.items():
if data.dtype == object:
for i in range(len(data)):
data_type = type(data[i])
if data_type == str:
data[i] = str(data[i])
elif data_type == bytes:
data[i] = str(data[i])
elif data_type == np.ndarray:
if data_type in (str, bytes, np.ndarray):
data[i] = str(data[i])
else:
data[i] = str(data[i])
try:
h5_cache.create_dataset(
feature_name,
data=data
)
h5_cache.create_dataset(feature_name, data=data)
except Exception as e:
logger.error(f"Error saving {feature_name} to cache: {e} with data {data}")
else:
h5_cache.create_dataset(feature_name, data=data)
h5_cache.close()
h5_cache = h5py.File(self.cache_file_name, "r")
return h5_cache
else:
return np_cache



def _transcode_pickled_images(self, ending_timestamp: Optional[int] = None):
"""
Transcode pickled images into the desired format (e.g., raw or encoded images).
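
The core of the commit is the pattern in _async_write_to_cache: the blocking HDF5 write runs on a single-worker thread pool so it does not stall the event loop, and an asyncio.Lock serializes writers. Below is a self-contained sketch of that pattern with illustrative names and a throwaway output path (not code from the repository):

import asyncio
from concurrent.futures import ThreadPoolExecutor

import h5py
import numpy as np

def write_to_cache(path, np_cache):
    # Blocking HDF5 write; meant to run off the event loop.
    with h5py.File(path, "w") as f:
        for name, data in np_cache.items():
            f.create_dataset(name, data=data)

async def async_write_to_cache(path, np_cache, executor, lock):
    # Serialize writers, then hand the blocking write to the worker thread.
    async with lock:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(executor, write_to_cache, path, np_cache)

async def main():
    executor = ThreadPoolExecutor(max_workers=1)
    lock = asyncio.Lock()
    await async_write_to_cache(
        "/tmp/traj_cache.h5",
        {"action": np.zeros((10, 7), dtype=np.float32)},
        executor,
        lock,
    )

asyncio.run(main())

One caveat worth noting: the synchronous load() wrapper calls asyncio.get_event_loop().run_until_complete(...), which only succeeds when no event loop is already running in the calling thread (true for plain synchronous callers such as the benchmarks above).
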
