Parallel writing and other fixes #23

Merged: 17 commits, Aug 20, 2024

Changes from all commits
147 changes: 134 additions & 13 deletions src/ome2024_ngff_challenge/resave.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import argparse
import itertools
import json
import logging
import math
@@ -29,17 +30,70 @@
#


class Batched:
    """
    Implementation of itertools.batched for pre-3.12 Python versions,
    from https://mathspp.com/blog/itertools-batched
    """

    def __init__(self, iterable, n: int):
        if n < 1:
            msg = f"n must be at least one ({n})"
            raise ValueError(msg)
        self.iter = iter(iterable)
        self.n = n

    def __iter__(self):
        return self

    def __next__(self):
        batch = tuple(itertools.islice(self.iter, self.n))
        if not batch:
            raise StopIteration()
        return batch
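
# Illustrative usage (not part of the diff): Batched yields fixed-size tuples
# until the iterable is exhausted; the final batch may be shorter.
#
#     >>> list(Batched(range(7), 3))
#     [(0, 1, 2), (3, 4, 5), (6,)]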


class SafeEncoder(json.JSONEncoder):
    # Handle any TypeErrors so we are safe to use this for logging
    # E.g. dtype obj is not JSON serializable
    def default(self, o):
        try:
            return super().default(o)
        except TypeError:
            return str(o)
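
# Illustrative sketch (not part of the diff; assumes numpy is installed):
# values the stock encoder rejects, e.g. a numpy dtype, are stringified
# instead of raising, so logging the tensorstore configs cannot fail here.
#
#     >>> import numpy as np
#     >>> json.dumps({"dtype": np.dtype("uint8")}, cls=SafeEncoder)
#     '{"dtype": "uint8"}'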


def guess_shards(shape: list, chunks: list):
    """
    Method to calculate best shard sizes. These values can be written to
    a file for the current dataset by using:

    ./resave.py input.zarr output.json --output-write-details
    """
-   # TODO: hard-coded to return the full size unless too large
-   if math.prod(shape) < 100_000_000:
-       return shape
-   raise ValueError(f"no shard guess: shape={shape}, chunks={chunks}")
    # TODO: hard-coded to return the full size
    assert chunks is not None  # fixes unused parameter
    return shape
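
# Illustrative behavior (not part of the diff): the guess is now always the
# full shape; the "too large" check has moved into convert_image() below.
#
#     >>> guess_shards([512, 512, 64], [256, 256, 32])
#     [512, 512, 64]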


def chunk_iter(shape: list, chunks: list):
    """
    Returns a series of tuples, each containing chunk slices.
    E.g. for 2D shape/chunks:
    ((slice(0, 512, 1), slice(0, 512, 1)), (slice(0, 512, 1), slice(512, 1024, 1)), ...)
    Thanks to Davis Bennett.
    """
    assert len(shape) == len(chunks)
    chunk_iters = []
    for chunk_size, dim_size in zip(chunks, shape):
        chunk_tuple = tuple(
            slice(
                c_index * chunk_size,
                min(dim_size, c_index * chunk_size + chunk_size),
                1,
            )
            # -(-a // b) is ceiling division: the number of chunks along this dimension
            for c_index in range(-(-dim_size // chunk_size))
        )
        chunk_iters.append(chunk_tuple)
    return tuple(itertools.product(*chunk_iters))
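
# Illustrative sketch (not part of the diff): a 4x4 array with 4x2 chunks is
# covered by two slice-tuples.
#
#     >>> chunk_iter([4, 4], [4, 2])
#     ((slice(0, 4, 1), slice(0, 2, 1)), (slice(0, 4, 1), slice(2, 4, 1)))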


def csv_int(vstr, sep=",") -> list:
@@ -53,7 +107,7 @@ def csv_int(vstr, sep=",") -> list:
            values.append(v)
        except ValueError as ve:
            raise argparse.ArgumentError(
-               message="Invalid value %s, values must be a number" % v0
                message=f"Invalid value {v0}, values must be a number"
            ) from ve
    return values

@@ -237,7 +291,9 @@ def check_or_delete_path(self):
            else:
                shutil.rmtree(self.path)
        else:
-           raise Exception(f"{self.path} exists. Exiting")
            raise Exception(
                f"{self.path} exists. Use --output-overwrite to overwrite"
            )

    def open_group(self):
        # Needs zarr_format=2 or we get ValueError("store mode does not support writing")
@@ -291,6 +347,7 @@ def convert_array(
    dimension_names: list,
    chunks: list,
    shards: list,
    threads: int,
):
    read = input_config.ts_read()

@@ -340,13 +397,44 @@ def convert_array(
write_config["create"] = True
write_config["delete_existing"] = output_config.overwrite

LOGGER.log(
5,
f"""input_config:
{json.dumps(input_config.ts_config, indent=4)}
""",
)
LOGGER.log(
5,
f"""write_config:
{json.dumps(write_config, indent=4, cls=SafeEncoder)}
""",
)

verify_config = base_config.copy()

write = ts.open(write_config).result()

before = TSMetrics(input_config.ts_config, write_config)
future = write.write(read)
future.result()

# read & write a chunk (or shard) at a time:
blocks = shards if shards is not None else chunks
for idx, batch in enumerate(Batched(chunk_iter(read.shape, blocks), threads)):
start = time.time()
with ts.Transaction() as txn:
LOGGER.log(5, f"batch {idx:03d}: scheduling transaction size={len(batch)}")
for slice_tuple in batch:
write.with_transaction(txn)[slice_tuple] = read[slice_tuple]
LOGGER.log(
5, f"batch {idx:03d}: {slice_tuple} scheduled in transaction"
)
LOGGER.log(5, f"batch {idx:03d}: waiting on transaction size={len(batch)}")
stop = time.time()
elapsed = stop - start
avg = float(elapsed) / len(batch)
LOGGER.debug(
f"batch {idx:03d}: completed transaction size={len(batch)} in {stop-start:0.2f}s (avg={avg:0.2f})"
)
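
    # Note (illustrative, not part of the diff): each ts.Transaction above
    # stages up to `threads` block copies and commits them together when the
    # `with` block exits, so the copy proceeds in bounded batches instead of
    # the single write.write(read).result() call it replaces.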

    after = TSMetrics(input_config.ts_config, write_config, before)

    LOGGER.info(f"""Re-encode (tensorstore) {input_config} to {output_config}
@@ -374,6 +462,7 @@ def convert_image(
    output_read_details: str | None,
    output_write_details: bool,
    output_script: bool,
    threads: int,
):
    dimension_names = None
    # top-level version...
@@ -417,13 +506,21 @@ def convert_image(
            with output_config.path.open(mode="w") as o:
                json.dump(details, o)
        else:
-           if output_chunks:
-               ds_chunks = output_chunks
-               ds_shards = output_shards
-           elif output_read_details:
            if output_read_details:
                # read row by row and overwrite
                ds_chunks = details[idx]["chunks"]
                ds_shards = details[idx]["shards"]
            else:
                if output_chunks:
                    ds_chunks = output_chunks
                if output_shards:
                    ds_shards = output_shards
                elif not output_script and math.prod(ds_shards) > 100_000_000:
                    # if we're going to convert, and we guessed the shards,
                    # let's validate the guess...
                    raise ValueError(
                        f"no shard guess: shape={ds_shape}, chunks={ds_chunks}"
                    )

            if output_script:
                chunk_txt = ",".join(map(str, ds_chunks))
@@ -440,6 +537,7 @@
                dimension_names,
                ds_chunks,
                ds_shards,
                threads,
            )


@@ -549,6 +647,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
            ns.output_read_details,
            ns.output_write_details,
            ns.output_script,
            ns.output_threads,
        )
        converted += 1

@@ -602,6 +701,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
                ns.output_read_details,
                ns.output_write_details,
                ns.output_script,
                ns.output_threads,
            )
            converted += 1
            # Note: plates can *also* contain this metadata
@@ -644,6 +744,7 @@ def main(ns: argparse.Namespace, rocrate: ROCrateWriter | None = None) -> int:
                ns.output_read_details,
                ns.output_write_details,
                ns.output_script,
                ns.output_threads,
            )
            converted += 1
        else:
@@ -669,12 +770,21 @@ def cli(args=sys.argv[1:]):
parser.add_argument("--output-region", default="us-east-1")
parser.add_argument("--output-overwrite", action="store_true")
parser.add_argument("--output-script", action="store_true")
parser.add_argument(
"--output-threads",
type=int,
default=16,
help="number of simultaneous write threads",
)
parser.add_argument("--rocrate-name", type=str)
parser.add_argument("--rocrate-description", type=str)
parser.add_argument("--rocrate-license", type=str)
parser.add_argument("--rocrate-organism", type=str)
parser.add_argument("--rocrate-modality", type=str)
parser.add_argument("--rocrate-skip", action="store_true")
parser.add_argument(
"--log", default="warn", help="'error', 'warn', 'info', 'debug' or 'trace'"
)
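
    # Illustrative invocation (not part of the diff; paths are placeholders):
    #
    #     ./resave.py input.zarr output.zarr --output-threads=8 --log=trace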
    group_ex = parser.add_mutually_exclusive_group()
    group_ex.add_argument(
        "--output-write-details",
@@ -698,7 +808,18 @@ def cli(args=sys.argv[1:]):
parser.add_argument("output_path", type=Path)
ns = parser.parse_args(args)

logging.basicConfig()
# configure logging
if ns.log.upper() == "TRACE":
numeric_level = 5
else:
numeric_level = getattr(logging, ns.log.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f"Invalid log level: {ns.log}. Use 'info' or 'debug'")
logging.basicConfig(
level=numeric_level,
format="%(asctime)s.%(msecs)03d %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
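
    # Note (illustrative, not part of the diff): "trace" maps to numeric
    # level 5, below logging.DEBUG (10), which is why the per-batch messages
    # in convert_array() use LOGGER.log(5, ...).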

    rocrate = None
    if not ns.rocrate_skip: