Skip to content

Commit

Permalink
feat: support min and max pooling for downsampling (#176)
Browse files Browse the repository at this point in the history
* feat: support min and max pooling for downsampling

* refactor: use the downsample method function exclusively

* fix: wrong dependency for tinybrain

* fixtest: use updated invocation of num_mips_from_memory_target
  • Loading branch information
william-silversmith authored May 21, 2024
1 parent 4736d98 commit 478d1f3
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 38 deletions.
11 changes: 10 additions & 1 deletion igneous/task_creation/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)

from igneous.shards import image_shard_shape_from_spec
from igneous.types import ShapeType
from igneous.types import ShapeType, DownsampleMethods

from .common import (
operator_contact, FinelyDividedTaskIterator,
Expand Down Expand Up @@ -210,6 +210,7 @@ def create_downsampling_tasks(
bounds_mip:int = 0,
memory_target:int = MEMORY_TARGET,
encoding_level:Optional[int] = None,
method:int = DownsampleMethods.AUTO,
):
"""
Creates a set of unsharded downsampling tasks and inserts them into the queue.
Expand Down Expand Up @@ -307,6 +308,7 @@ def task(self, shape, offset):
compress=compress,
factor=factor,
max_mips=num_mips,
method=method,
)

def on_finish(self):
Expand All @@ -329,6 +331,7 @@ def on_finish(self):
'dest_path': dest_path,
'compress': compress,
'factor': (tuple(factor) if factor else None),
'method': method,
},
'by': operator_contact(),
'date': strftime('%Y-%m-%d %H:%M %Z'),
Expand Down Expand Up @@ -605,6 +608,7 @@ def create_image_shard_downsample_tasks(
agglomerate=False, timestamp=None,
factor=(2,2,1), bounds=None, bounds_mip=0,
encoding_level:Optional[int] = None,
method=DownsampleMethods.AUTO,
):
"""
Downsamples an existing image layer that may be
Expand Down Expand Up @@ -652,6 +656,7 @@ def task(self, shape, offset):
agglomerate=bool(agglomerate),
timestamp=timestamp,
factor=tuple(factor),
method=method,
)

def on_finish(self):
Expand All @@ -666,6 +671,7 @@ def on_finish(self):
"mip": mip,
"agglomerate": agglomerate,
"timestamp": timestamp,
"method": method,
},
"by": operator_contact(),
"date": strftime("%Y-%m-%d %H:%M %Z"),
Expand Down Expand Up @@ -806,6 +812,7 @@ def create_transfer_tasks(
truncate_scales:bool = True,
cutout:bool = False,
stop_layer:Optional[int] = None,
downsample_method:int = DownsampleMethods.AUTO,
) -> Iterator:
"""
Transfer data to a new data layer. You can use this operation
Expand Down Expand Up @@ -970,6 +977,7 @@ def task(self, shape, offset):
factor=factor,
sparse=sparse,
stop_layer=stop_layer,
downsample_method=int(downsample_method),
)

def on_finish(self):
Expand Down Expand Up @@ -998,6 +1006,7 @@ def on_finish(self):
'sparse': bool(sparse),
'encoding_level': encoding_level,
'stop_layer': stop_layer,
'downsample_method': int(downsample_method),
},
'by': operator_contact(),
'date': strftime('%Y-%m-%d %H:%M %Z'),
Expand Down
59 changes: 37 additions & 22 deletions igneous/tasks/image/image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Sequence

from functools import partial
import json
import math
import os
Expand All @@ -25,18 +26,39 @@

import igneous.shards
from igneous import downsample_scales
from igneous.types import ShapeType
from igneous.types import ShapeType, DownsampleMethods

from .obsolete import (
HyperSquareConsensusTask, WatershedRemapTask,
MaskAffinitymapTask, InferenceTask
)

def downsample_method_to_fn(method, sparse, vol):
  """Resolve a DownsampleMethods selector to the tinybrain function that implements it.

  Args:
    method: a DownsampleMethods value; AUTO picks a method from vol.layer_type.
    sparse: forwarded to the averaging/mode downsamplers (ignored by min/max/striding).
    vol: CloudVolume-like object; only its layer_type attribute is read here.

  Returns:
    A callable with tinybrain's downsample signature (image, factor, num_mips=...).
  """
  if method == DownsampleMethods.AUTO:
    # Mirror the historical defaults: averaging for images, mode for
    # segmentation, striding for anything else.
    layer = vol.layer_type
    if layer == 'image':
      method = DownsampleMethods.AVERAGE_POOLING
    elif layer == 'segmentation':
      method = DownsampleMethods.MODE_POOLING
    else:
      method = DownsampleMethods.STRIDING

  if method == DownsampleMethods.AVERAGE_POOLING:
    return partial(tinybrain.downsample_with_averaging, sparse=sparse)
  if method == DownsampleMethods.MODE_POOLING:
    return partial(tinybrain.downsample_segmentation, sparse=sparse)
  if method == DownsampleMethods.MIN_POOLING:
    return tinybrain.downsample_with_min_pooling
  if method == DownsampleMethods.MAX_POOLING:
    return tinybrain.downsample_with_max_pooling
  # Any unrecognized value falls back to striding, matching the original chain.
  return tinybrain.downsample_with_striding

def downsample_and_upload(
image, bounds, vol, ds_shape,
mip=0, axis='z', skip_first=False,
sparse=False, factor=None, max_mips=None
):
image, bounds, vol, ds_shape,
mip=0, axis='z', skip_first=False,
sparse=False, factor=None, max_mips=None,
method=DownsampleMethods.AUTO,
):
ds_shape = min2(vol.volume_size, ds_shape[:3])
underlying_mip = (mip + 1) if (mip + 1) in vol.available_mips else mip
chunk_size = vol.meta.chunk_size(underlying_mip).astype(np.float32)
Expand All @@ -63,19 +85,10 @@ def downsample_and_upload(
num_mips = len(factors)

mips = []
if vol.layer_type == 'image':
mips = tinybrain.downsample_with_averaging(
image, factors[0],
num_mips=num_mips, sparse=sparse
)
elif vol.layer_type == 'segmentation':
mips = tinybrain.downsample_segmentation(
image, factors[0],
num_mips=num_mips, sparse=sparse
)
else:
mips = tinybrain.downsample_with_striding(image, factors[0], num_mips=num_mips)

fn = downsample_method_to_fn(method, sparse, vol)
mips = fn(image, factors[0], num_mips=num_mips)

new_bounds = bounds.clone()

for factor3 in factors:
Expand Down Expand Up @@ -384,6 +397,7 @@ def TransferTask(
factor=None,
max_mips:Optional[int] = None,
stop_layer:Optional[int] = None,
downsample_method:str = DownsampleMethods.AUTO,
):
"""
Transfer an image to a new location while enabling
Expand Down Expand Up @@ -445,6 +459,7 @@ def TransferTask(
skip_first=skip_first,
sparse=sparse, axis=axis,
factor=factor, max_mips=max_mips,
method=downsample_method,
)

@queueable
Expand All @@ -453,7 +468,7 @@ def DownsampleTask(
fill_missing=False, axis='z', sparse=False,
delete_black_uploads=False, background_color=0,
dest_path=None, compress="gzip", factor=None,
max_mips=None,
max_mips=None, method=DownsampleMethods.AUTO,
):
"""
Downsamples a cutout of the volume. By default it performs
Expand All @@ -477,6 +492,7 @@ def DownsampleTask(
compress=compress,
factor=factor,
max_mips=max_mips,
downsample_method=DownsampleMethods.AUTO,
)

@queueable
Expand Down Expand Up @@ -608,7 +624,8 @@ def ImageShardDownsampleTask(
sparse: bool = False,
agglomerate: bool = False,
timestamp: Optional[int] = None,
factor: ShapeType = (2,2,1)
factor: ShapeType = (2,2,1),
method: int = DownsampleMethods.AUTO,
):
"""
Generate a single downsample level for a shard.
Expand Down Expand Up @@ -656,9 +673,7 @@ def ImageShardDownsampleTask(
output_img = np.zeros(shard_shape, dtype=src_vol.dtype, order="F")
nz = int(math.ceil(bbox.dz / (chunk_size.z * factor[2])))

dsfn = tinybrain.downsample_with_averaging
if src_vol.layer_type == "segmentation":
dsfn = tinybrain.downsample_segmentation
dsfn = downsample_method_to_fn(method, sparse, vol)

zbox = bbox.clone()
zbox.maxpt.z = zbox.minpt.z + (chunk_size.z * factor[2])
Expand Down
11 changes: 10 additions & 1 deletion igneous/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
import enum
from typing import Any, Dict, Tuple, Union, Optional

ShapeType = Tuple[int, int, int]
ShapeType = Tuple[int, int, int]

class DownsampleMethods(enum.IntEnum):
  """Selector for the downsampling algorithm used when generating image pyramids.

  Values are serialized into task provenance, so the member order below is
  load-bearing: auto() assigns 1..6 in declaration order, identical to the
  previous explicit assignments.
  """
  AVERAGE_POOLING = enum.auto()  # 1: mean of each pooling window (images)
  MODE_POOLING = enum.auto()     # 2: most common label per window (segmentation)
  MIN_POOLING = enum.auto()      # 3: minimum of each window
  MAX_POOLING = enum.auto()      # 4: maximum of each window
  STRIDING = enum.auto()         # 5: plain subsampling, no aggregation
  AUTO = enum.auto()             # 6: pick a method from the volume's layer_type
25 changes: 23 additions & 2 deletions igneous_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from igneous import task_creation as tc
from igneous import downsample_scales
from igneous.secrets import LEASE_SECONDS, SQS_REGION_NAME
from igneous.types import DownsampleMethods

from igneous_cli.humanbytes import format_bytes

Expand Down Expand Up @@ -123,6 +124,24 @@ class CloudPath(click.ParamType):
def convert(self, value, param, ctx):
return cloudfiles.paths.normalize(value)

class DownsampleMethodType(click.ParamType):
  """Click parameter type mapping CLI method names to DownsampleMethods values."""
  name = "DownsampleMethod"

  # CLI spelling -> enum value. Keep in sync with the --method help text.
  _METHODS = {
    "auto": DownsampleMethods.AUTO,
    "avg": DownsampleMethods.AVERAGE_POOLING,
    "mode": DownsampleMethods.MODE_POOLING,
    "min": DownsampleMethods.MIN_POOLING,
    "max": DownsampleMethods.MAX_POOLING,
    "striding": DownsampleMethods.STRIDING,
  }

  def convert(self, value, param, ctx):
    """Convert a CLI string into a DownsampleMethods member.

    click may call convert on an already-converted default, so enum values
    pass through unchanged. Unknown names are reported via self.fail, which
    raises a click UsageError (clean CLI message) instead of a ValueError
    traceback as before.
    """
    if isinstance(value, DownsampleMethods):
      return value
    try:
      return self._METHODS[value]
    except KeyError:
      self.fail(f"Downsample method {value} not supported.", param, ctx)

def compute_bounds(path, mip, xrange, yrange, zrange):
bounds = None
if xrange or yrange or zrange:
Expand Down Expand Up @@ -208,6 +227,7 @@ def imagegroup():
@click.option('--bg-color', default=0, help="Determines which color is regarded as background. Default: 0")
@click.option('--sharded', is_flag=True, default=False, help="Generate sharded downsamples which reduces the number of files.")
@click.option('--memory', default=3.5e9, type=int, help="(sharded only) Task memory limit in bytes. Task shape will be chosen to fit and maximize downsamples.", show_default=True)
@click.option('--method', default="auto", type=DownsampleMethodType(), help="Select the downsample method type. Options: auto, avg, mode, min, max, striding", show_default=True)
@click.option('--xrange', type=Tuple2(), default=None, help="If specified, set x-bounds for downsampling in terms of selected mip. By default the whole dataset is selected. The bounds must be chunk aligned to the task size (maybe mysterious... use igneous design to investigate). e.g. 0,1024.", show_default=True)
@click.option('--yrange', type=Tuple2(), default=None, help="If specified, set y-bounds for downsampling in terms of selected mip. By default the whole dataset is selected. The bounds must be chunk aligned to the task size (maybe mysterious... use igneous design to investigate). e.g. 0,1024", show_default=True)
@click.option('--zrange', type=Tuple2(), default=None, help="If specified, set z-bounds for downsampling in terms of selected mip. By default the whole dataset is selected. The bounds must be chunk aligned to the task size (maybe mysterious... use igneous design to investigate). e.g. 0,1", show_default=True)
Expand All @@ -217,7 +237,7 @@ def downsample(
num_mips, encoding, encoding_level, sparse,
chunk_size, compress, volumetric,
delete_bg, bg_color, sharded, memory,
xrange, yrange, zrange,
xrange, yrange, zrange, method,
):
"""
Create an image pyramid for grayscale or labeled images.
Expand Down Expand Up @@ -250,7 +270,7 @@ def downsample(
sparse=sparse, chunk_size=chunk_size,
encoding=encoding, memory_target=memory,
factor=factor, bounds=bounds, bounds_mip=mip,
encoding_level=encoding_level,
encoding_level=encoding_level, method=method,
)
else:
tasks = tc.create_downsampling_tasks(
Expand All @@ -264,6 +284,7 @@ def downsample(
bounds_mip=mip,
memory_target=memory,
encoding_level=encoding_level,
downsample_method=method,
)

enqueue_tasks(ctx, queue, tasks)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pytest>=3.3.1
pytz
scipy
shard-computer
tinybrain
tinybrain>=1.5.0
task-queue>=2.4.0
tqdm
trimesh[easy]
Expand Down
23 changes: 12 additions & 11 deletions test/test_tasks.py
Original file line number Diff line number Diff line change
def test_num_mips_from_memory_target():
  """num_mips_from_memory_target picks mip counts that fit the memory budget.

  This consolidates the diff residue above: each call is made once with the
  updated signature (which takes num_channels), dropping the stale
  pre-change invocations that omitted it.
  """
  chunk_size = (128,128,64)
  num_channels = 1

  # (memory budget in bytes, dtype, downsample factor, expected mip count)
  cases = [
    (0,      'uint8',  (2,2,1), 1),
    (100e6,  'uint8',  (2,2,1), 3),
    (100e6,  'uint16', (2,2,1), 2),
    (100e6,  'uint32', (2,2,1), 2),
    (100e6,  'uint64', (2,2,1), 1),
    (3.5e9,  'uint64', (2,2,1), 4),
    (12e9,   'uint64', (2,2,1), 5),
    (800e6,  'uint8',  (2,2,2), 3),
    (500e6,  'uint8',  (2,2,2), 2),
    (100e6,  'uint8',  (2,2,2), 2),
    (50e6,   'uint8',  (2,2,2), 1),
  ]

  for memory, dtype, factor, expected in cases:
    num_mips = num_mips_from_memory_target(
      memory, dtype, chunk_size, num_channels, factor
    )
    assert num_mips == expected, (memory, dtype, factor)

0 comments on commit 478d1f3

Please sign in to comment.