Reduce readers' memory consumption #229

Open
marcovarrone opened this issue Oct 28, 2024 · 1 comment

@marcovarrone

marcovarrone commented Oct 28, 2024

In the current implementation, a reader loads all the elements from the raw files into memory and creates a SpatialData object. It is then the user's responsibility to save the object to disk.
For big samples, this leads to large memory requirements, because the whole object has to be held in memory at once.

I propose adding a new `output_path` parameter to the reader functions so that every element of the SpatialData object is written to disk as soon as it is created. This frees part of the memory while the function is still running.
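A minimal, library-agnostic sketch of the proposed control flow, assuming each element can be produced by a zero-argument loader. `read_elements` and the loaders are hypothetical names, and `pickle` only stands in for the zarr writing that the draft PR does through `SpatialData`; the point is that, when `output_path` is given, each element is written and its reference dropped before the next one is loaded.

```python
from __future__ import annotations

import gc
import pickle
from pathlib import Path
from typing import Any, Callable, Iterable


def read_elements(
    loaders: Iterable[tuple[str, Callable[[], Any]]],
    output_path: Path | None = None,
) -> dict[str, Any]:
    """Load elements one at a time (hypothetical helper, not the spatialdata API).

    Without output_path, every element is kept and returned (current behavior).
    With output_path, each element is written right after it is created and then
    dropped, so at most one element is held in memory; the returned dict is empty.
    """
    elements: dict[str, Any] = {}
    if output_path is not None:
        # Refuse to reuse an existing directory, similar in spirit to overwrite=False.
        output_path.mkdir(parents=True, exist_ok=False)
    for name, load in loaders:
        element = load()
        if output_path is None:
            elements[name] = element
            continue
        # pickle is only a stand-in for writing the element into a zarr store.
        with open(output_path / f"{name}.pkl", "wb") as f:
            pickle.dump(element, f)
        # Drop the last reference so the memory can be reclaimed before the next load.
        del element
        gc.collect()
    return elements
```

Calling it without `output_path` reproduces today's behavior; passing a fresh directory gives the write-as-you-go behavior proposed above.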

I have opened a draft pull request (#228) for Xenium only; if I get the OK, I can implement it for the rest of the readers.
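The effect on peak memory can also be reproduced at toy scale with the standard-library `tracemalloc` module. In this sketch, fake 8 MiB `bytearray` objects stand in for large raster elements such as `nucleus_labels`, and all function names are illustrative rather than part of spatialdata. Note that `tracemalloc` only sees allocations made through Python's allocators, so unlike memory_profiler it does not report the process's resident memory.

```python
from __future__ import annotations

import gc
import tempfile
import tracemalloc
from pathlib import Path

ELEMENT_BYTES = 8 * 2**20  # each fake element is 8 MiB
N_ELEMENTS = 5


def make_element() -> bytearray:
    # Stand-in for a large element (e.g. a labels layer) produced by a reader.
    return bytearray(ELEMENT_BYTES)


def keep_all_in_memory() -> dict[str, bytearray]:
    # Current behavior: every element stays referenced until the end.
    return {f"element_{i}": make_element() for i in range(N_ELEMENTS)}


def write_as_you_go(output_path: Path) -> None:
    # Proposed behavior: write each element, then drop the reference.
    for i in range(N_ELEMENTS):
        element = make_element()
        (output_path / f"element_{i}.bin").write_bytes(element)
        del element
        gc.collect()


def peak_mib(func, *args) -> float:
    """Peak traced allocation, in MiB, reached while func runs."""
    tracemalloc.start()
    func(*args)
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak / 2**20


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(f"keep all in memory: {peak_mib(keep_all_in_memory):.1f} MiB")
        print(f"write as you go:    {peak_mib(write_as_you_go, Path(tmp)):.1f} MiB")
```

Keeping everything should peak at roughly `N_ELEMENTS` elements' worth of memory, while writing as you go should stay close to a single element; the exact figures depend on the Python version.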

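Here is a comparison using memory_profiler on the Xenium reader. In the current version, memory usage peaks at 64360.1 MiB because every element stays in memory until the SpatialData object is returned. In the new version, writing `nucleus_labels` to disk and deleting it brings usage down from 35070.3 MiB to 7119.5 MiB, so after the cell labels are loaded the process holds 35562.1 MiB instead of 63019.2 MiB.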

Current version

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   55    366.7 MiB    366.7 MiB           1   @deprecation_alias(cells_as_shapes="cells_as_circles", cell_boundaries="cells_boundaries", cell_labels="cells_labels")
    56                                         @inject_docs(xx=XeniumKeys)
    57                                         @profile(stream=fp)
    58                                         def xenium(
    59                                             path: str | Path,
    60                                             *,
    61                                             cells_boundaries: bool = True,
    62                                             nucleus_boundaries: bool = True,
    63                                             cells_as_circles: bool | None = None,
    64                                             cells_labels: bool = True,
    65                                             nucleus_labels: bool = True,
    66                                             transcripts: bool = True,
    67                                             morphology_mip: bool = True,
    68                                             morphology_focus: bool = True,
    69                                             aligned_images: bool = True,
    70                                             cells_table: bool = True,
    71                                             n_jobs: int = 1,
    72                                             imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
    73                                             image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
    74                                             labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
    75                                         ) -> SpatialData:

   155    366.7 MiB      0.0 MiB           1       if cells_as_circles is None:
   156                                                 cells_as_circles = True
   157                                                 warnings.warn(
   158                                                     "The default value of `cells_as_circles` will change to `False` in the next release. "
   159                                                     "Please pass `True` explicitly to maintain the current behavior.",
   160                                                     DeprecationWarning,
   161                                                     stacklevel=3,
   162                                                 )
   163    366.7 MiB      0.0 MiB           2       image_models_kwargs, labels_models_kwargs = _initialize_raster_models_kwargs(
   164    366.7 MiB      0.0 MiB           1           image_models_kwargs, labels_models_kwargs
   165                                             )
   166    366.7 MiB      0.0 MiB           1       path = Path(path)
   167    366.7 MiB      0.0 MiB           2       with open(path / XeniumKeys.XENIUM_SPECS) as f:
   168    366.7 MiB      0.0 MiB           1           specs = json.load(f)
   169                                             # to trigger the warning if the version cannot be parsed
   170    366.7 MiB      0.0 MiB           1       version = _parse_version_of_xenium_analyzer(specs, hide_warning=False)
   171                                         
   172    366.7 MiB      0.0 MiB           1       specs["region"] = "cell_circles" if cells_as_circles else "cell_labels"
   173                                         
   174                                             # the table is required in some cases
   175    366.7 MiB      0.0 MiB           1       if not cells_table:
   176                                                 if cells_as_circles:
   177                                                     logging.info(
   178                                                         'When "cells_as_circles" is set to `True` reading the table is required; setting `cell_annotations` to '
   179                                                         "`True`."
   180                                                     )
   181                                                     cells_table = True
   182                                                 if cells_boundaries or nucleus_boundaries:
   183                                                     logging.info(
   184                                                         'When "cell_boundaries" or "nucleus_boundaries" is set to `True` reading the table is required; '
   185                                                         "setting `cell_annotations` to `True`."
   186                                                     )
   187                                                     cells_table = True
   188                                         
   189    366.7 MiB      0.0 MiB           1       if cells_table:
   190   1206.8 MiB    840.1 MiB           1           return_values = _get_tables_and_circles(path, cells_as_circles, specs)
   191   1206.8 MiB      0.0 MiB           1           if cells_as_circles:
   192   1206.8 MiB      0.0 MiB           1               table, circles = return_values
   193                                                 else:
   194                                                     table = return_values
   195                                             else:
   196                                                 table = None
   197                                         
   198   1206.8 MiB      0.0 MiB           1       if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
   199                                                 cell_summary_table = _get_cells_metadata_table_from_zarr(path, XeniumKeys.CELLS_ZARR, specs)
   200                                                 if not cell_summary_table[XeniumKeys.CELL_ID].equals(table.obs[XeniumKeys.CELL_ID]):
   201                                                     warnings.warn(
   202                                                         'The "cell_id" column in the cells metadata table does not match the "cell_id" column in the annotation'
   203                                                         " table. This could be due to trying to read a new version that is not supported yet. Please "
   204                                                         "report this issue.",
   205                                                         UserWarning,
   206                                                         stacklevel=2,
   207                                                     )
   208                                                 table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL]
   209                                                 table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT]
   210                                         
   211   1206.8 MiB      0.0 MiB           1       polygons = {}
   212   1206.8 MiB      0.0 MiB           1       labels = {}
   213   1206.8 MiB      0.0 MiB           1       tables = {}
   214   1206.8 MiB      0.0 MiB           1       points = {}
   215   1206.8 MiB      0.0 MiB           1       images = {}
   216                                         
   217                                             # From the public release notes here:
   218                                             # https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/release-notes/release-notes-for-xoa
   219                                             # we see that for distinguishing between the nuclei of polinucleated cells, the `label_id` column is used.
   220                                             # This column is currently not found in the preview data, while I think it is needed in order to unambiguously match
   221                                             # nuclei to cells. Therefore for the moment we only link the table to the cell labels, and not to the nucleus
   222                                             # labels.
   223   1206.8 MiB      0.0 MiB           1       if nucleus_labels:
   224  34193.5 MiB  32986.8 MiB           2           labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
   225   1206.8 MiB      0.0 MiB           1               path,
   226   1206.8 MiB      0.0 MiB           1               XeniumKeys.CELLS_ZARR,
   227   1206.8 MiB      0.0 MiB           1               specs,
   228   1206.8 MiB      0.0 MiB           1               mask_index=0,
   229   1206.8 MiB      0.0 MiB           1               labels_name="nucleus_labels",
   230   1206.8 MiB      0.0 MiB           1               labels_models_kwargs=labels_models_kwargs,
   231                                                 )
   232  34192.5 MiB     -1.0 MiB           1           gc.collect()
   233  34192.5 MiB      0.0 MiB           1       if cells_labels:
   234  63019.2 MiB  28826.6 MiB           2           labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
   235  34192.5 MiB      0.0 MiB           1               path,
   236  34192.5 MiB      0.0 MiB           1               XeniumKeys.CELLS_ZARR,
   237  34192.5 MiB      0.0 MiB           1               specs,
   238  34192.5 MiB      0.0 MiB           1               mask_index=1,
   239  34192.5 MiB      0.0 MiB           1               labels_name="cell_labels",
   240  34192.5 MiB      0.0 MiB           1               labels_models_kwargs=labels_models_kwargs,
   241                                                 )
   242  63019.2 MiB      0.0 MiB           1           gc.collect()
   243  63019.2 MiB      0.0 MiB           1           if cell_labels_indices_mapping is not None and table is not None:
   244                                                     if not pd.DataFrame.equals(cell_labels_indices_mapping["cell_id"], table.obs[str(XeniumKeys.CELL_ID)]):
   245                                                         warnings.warn(
   246                                                             "The cell_id column in the cell_labels_table does not match the cell_id column derived from the "
   247                                                             "cell labels data. This could be due to trying to read a new version that is not supported yet. "
   248                                                             "Please report this issue.",
   249                                                             UserWarning,
   250                                                             stacklevel=2,
   251                                                         )
   252                                                     else:
   253                                                         table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"]
   254                                                         if not cells_as_circles:
   255                                                             table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"
   256                                         
   257  63019.2 MiB      0.0 MiB           1       if nucleus_boundaries:
   258  63669.8 MiB    650.7 MiB           2           polygons["nucleus_boundaries"] = _get_polygons(
   259  63019.2 MiB      0.0 MiB           1               path,
   260  63019.2 MiB      0.0 MiB           1               XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
   261  63019.2 MiB      0.0 MiB           1               specs,
   262  63019.2 MiB      0.0 MiB           1               n_jobs,
   263  63019.2 MiB      0.0 MiB           1               idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
   264                                                 )
   265  63669.8 MiB      0.0 MiB           1           gc.collect()
   266                                         
   267  63669.8 MiB      0.0 MiB           1       if cells_boundaries:
   268  64175.0 MiB    505.2 MiB           2           polygons["cell_boundaries"] = _get_polygons(
   269  63669.8 MiB      0.0 MiB           1               path,
   270  63669.8 MiB      0.0 MiB           1               XeniumKeys.CELL_BOUNDARIES_FILE,
   271  63669.8 MiB      0.0 MiB           1               specs,
   272  63669.8 MiB      0.0 MiB           1               n_jobs,
   273  63669.8 MiB      0.0 MiB           1               idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
   274                                                 )
   275  64175.0 MiB      0.0 MiB           1           gc.collect()
   276                                         
   277  64175.0 MiB      0.0 MiB           1       if transcripts:
   278  64360.1 MiB    185.0 MiB           1           points["transcripts"] = _get_points(path, specs)
   279  64359.1 MiB     -1.0 MiB           1           gc.collect()
   280                                         
   281  64359.1 MiB      0.0 MiB           1       if version is None or version < packaging.version.parse("2.0.0"):
   282  64359.1 MiB      0.0 MiB           1           if morphology_mip:
   283  64359.1 MiB   -213.3 MiB           2               images["morphology_mip"] = _get_images(
   284  64359.1 MiB      0.0 MiB           1                   path,
   285  64359.1 MiB      0.0 MiB           1                   XeniumKeys.MORPHOLOGY_MIP_FILE,
   286  64359.1 MiB      0.0 MiB           1                   imread_kwargs,
   287  64359.1 MiB      0.0 MiB           1                   image_models_kwargs,
   288                                                     )
   289  64145.7 MiB   -213.3 MiB           1               gc.collect()
   290  64145.7 MiB      0.0 MiB           1           if morphology_focus:
   291  64145.7 MiB      0.0 MiB           2               images["morphology_focus"] = _get_images(
   292  64145.7 MiB      0.0 MiB           1                   path,
   293  64145.7 MiB      0.0 MiB           1                   XeniumKeys.MORPHOLOGY_FOCUS_FILE,
   294  64145.7 MiB      0.0 MiB           1                   imread_kwargs,
   295  64145.7 MiB      0.0 MiB           1                   image_models_kwargs,
   296                                                     )
   297  64145.7 MiB     -0.0 MiB           1               gc.collect()
   298                                             else:
   299                                                 if morphology_focus:
   300                                                     morphology_focus_dir = path / XeniumKeys.MORPHOLOGY_FOCUS_DIR
   301                                                     files = {f for f in os.listdir(morphology_focus_dir) if f.endswith(".ome.tif")}
   302                                                     if len(files) not in [1, 4]:
   303                                                         raise ValueError(
   304                                                             "Expected 1 (no segmentation kit) or 4 (segmentation kit) files in the morphology focus directory, "
   305                                                             f"found {len(files)}: {files}"
   306                                                         )
   307                                                     if files != {XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(i) for i in range(len(files))}:
   308                                                         raise ValueError(
   309                                                             "Expected files in the morphology focus directory to be named as "
   310                                                             f"{XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(0)} to "
   311                                                             f"{XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(len(files) - 1)}, found {files}"
   312                                                         )
   313                                                     # the 'dummy' channel is a temporary workaround, see _get_images() for more details
   314                                                     if len(files) == 1:
   315                                                         channel_names = {
   316                                                             0: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0.value,
   317                                                         }
   318                                                     else:
   319                                                         channel_names = {
   320                                                             0: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0.value,
   321                                                             1: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_1.value,
   322                                                             2: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_2.value,
   323                                                             3: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_3.value,
   324                                                             4: "dummy",
   325                                                         }
   326                                                     # this reads the scale 0 for all the 1 or 4 channels (the other files are parsed automatically)
   327                                                     # dask.image.imread will call tifffile.imread which will give a warning saying that reading multi-file
   328                                                     # pyramids is not supported; since we are reading the full scale image and reconstructing the pyramid, we
   329                                                     # can ignore this
   330                                         
   331                                                     class IgnoreSpecificMessage(logging.Filter):
   332                                                         def filter(self, record: logging.LogRecord) -> bool:
   333                                                             # Ignore specific log message
   334                                                             if "OME series cannot read multi-file pyramids" in record.getMessage():
   335                                                                 return False
   336                                                             return True
   337                                         
   338                                                     logger = tifffile.logger()
   339                                                     logger.addFilter(IgnoreSpecificMessage())
   340                                                     image_models_kwargs = dict(image_models_kwargs)
   341                                                     assert (
   342                                                         "c_coords" not in image_models_kwargs
   343                                                     ), "The channel names for the morphology focus images are handled internally"
   344                                                     image_models_kwargs["c_coords"] = list(channel_names.values())
   345                                                     images["morphology_focus"] = _get_images(
   346                                                         morphology_focus_dir,
   347                                                         XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.format(0),
   348                                                         imread_kwargs,
   349                                                         image_models_kwargs,
   350                                                     )
   351                                                     del image_models_kwargs["c_coords"]
   352                                                     logger.removeFilter(IgnoreSpecificMessage())
   353                                         
   354  64145.7 MiB      0.0 MiB           1       if table is not None:
   355  64145.7 MiB      0.0 MiB           1           tables["table"] = table
   356  64145.7 MiB      0.0 MiB           1           gc.collect()
   357                                         
   358  64145.7 MiB      0.0 MiB           1       elements_dict = {"images": images, "labels": labels, "points": points, "tables": tables, "shapes": polygons}
   359  64145.7 MiB      0.0 MiB           1       if cells_as_circles:
   360  64145.7 MiB      0.0 MiB           1           elements_dict["shapes"][specs["region"]] = circles
   361  64145.7 MiB      0.0 MiB           1       sdata = SpatialData(**elements_dict)
   362                                         
   363                                             # find and add additional aligned images
   364  64145.7 MiB      0.0 MiB           1       if aligned_images:
   365  64145.7 MiB      0.0 MiB           1           extra_images = _add_aligned_images(path, imread_kwargs, image_models_kwargs)
   366  64145.7 MiB      0.0 MiB           2           for key, value in extra_images.items():
   367  64145.7 MiB      0.0 MiB           1               sdata.images[key] = value
   368  64145.7 MiB      0.0 MiB           1               gc.collect()
   369  64145.7 MiB      0.0 MiB           1       return sdata


New version (draft PR #228, with `output_path` set)

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    56    369.0 MiB    369.0 MiB           1   @deprecation_alias(cells_as_shapes="cells_as_circles", cell_boundaries="cells_boundaries", cell_labels="cells_labels")
    57                                         @inject_docs(xx=XeniumKeys)
    58                                         @profile(stream=fp)
    59                                         def xenium(
    60                                             path: str | Path,
    61                                             *,
    62                                             cells_boundaries: bool = True,
    63                                             nucleus_boundaries: bool = True,
    64                                             cells_as_circles: bool | None = None,
    65                                             cells_labels: bool = True,
    66                                             nucleus_labels: bool = True,
    67                                             transcripts: bool = True,
    68                                             morphology_mip: bool = True,
    69                                             morphology_focus: bool = True,
    70                                             aligned_images: bool = True,
    71                                             cells_table: bool = True,
    72                                             n_jobs: int = 1,
    73                                             imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
    74                                             image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
    75                                             labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
    76                                             output_path: Path | None = None,
    77                                         ) -> SpatialData:
    78                                             """

   133                                             output_path
   134                                                 Path to directly write the output to a zarr file. This can decrease the memory requirement. If not provided, the
   135                                                 function will return a :class:`spatialdata.SpatialData` object.

   159                                             """
   160    369.0 MiB      0.0 MiB           1       if cells_as_circles is None:
   161                                                 cells_as_circles = True
   162                                                 warnings.warn(
   163                                                     "The default value of `cells_as_circles` will change to `False` in the next release. "
   164                                                     "Please pass `True` explicitly to maintain the current behavior.",
   165                                                     DeprecationWarning,
   166                                                     stacklevel=3,
   167                                                 )
   168    369.0 MiB      0.0 MiB           2       image_models_kwargs, labels_models_kwargs = _initialize_raster_models_kwargs(
   169    369.0 MiB      0.0 MiB           1           image_models_kwargs, labels_models_kwargs
   170                                             )
   171    369.0 MiB      0.0 MiB           1       path = Path(path)
   172    369.0 MiB      0.0 MiB           1       output_path = Path(output_path) if output_path is not None else None
   173                                         
   174    369.1 MiB      0.0 MiB           2       with open(path / XeniumKeys.XENIUM_SPECS) as f:
   175    369.1 MiB      0.1 MiB           1           specs = json.load(f)
   176                                             # to trigger the warning if the version cannot be parsed
   177    369.1 MiB      0.0 MiB           1       version = _parse_version_of_xenium_analyzer(specs, hide_warning=False)
   178                                         
   179    369.1 MiB      0.0 MiB           1       specs["region"] = "cell_circles" if cells_as_circles else "cell_labels"
   180                                         
   181                                             # the table is required in some cases
   182    369.1 MiB      0.0 MiB           1       if not cells_table:
   183                                                 if cells_as_circles:
   184                                                     logging.info(
   185                                                         'When "cells_as_circles" is set to `True` reading the table is required; setting `cell_annotations` to '
   186                                                         "`True`."
   187                                                     )
   188                                                     cells_table = True
   189                                                 if cells_boundaries or nucleus_boundaries:
   190                                                     logging.info(
   191                                                         'When "cell_boundaries" or "nucleus_boundaries" is set to `True` reading the table is required; '
   192                                                         "setting `cell_annotations` to `True`."
   193                                                     )
   194                                                     cells_table = True
   195                                         
   196    369.1 MiB      0.0 MiB           1       if cells_table:
   197   1226.1 MiB    857.0 MiB           1           return_values = _get_tables_and_circles(path, cells_as_circles, specs)
   198   1226.1 MiB      0.0 MiB           1           if cells_as_circles:
   199   1226.1 MiB      0.0 MiB           1               table, circles = return_values
   200                                                 else:
   201                                                     table = return_values
   202                                             else:
   203                                                 table = None
   204                                         
   205   1226.1 MiB      0.0 MiB           1       if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
   206                                                 cell_summary_table = _get_cells_metadata_table_from_zarr(path, XeniumKeys.CELLS_ZARR, specs)
   207                                                 if not cell_summary_table[XeniumKeys.CELL_ID].equals(table.obs[XeniumKeys.CELL_ID]):
   208                                                     warnings.warn(
   209                                                         'The "cell_id" column in the cells metadata table does not match the "cell_id" column in the annotation'
   210                                                         " table. This could be due to trying to read a new version that is not supported yet. Please "
   211                                                         "report this issue.",
   212                                                         UserWarning,
   213                                                         stacklevel=2,
   214                                                     )
   215                                                 table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL]
   216                                                 table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT]
   217                                         
   218   1226.1 MiB      0.0 MiB           1       polygons = {}
   219   1226.1 MiB      0.0 MiB           1       labels = {}
   220   1226.1 MiB      0.0 MiB           1       tables = {}
   221   1226.1 MiB      0.0 MiB           1       points = {}
   222   1226.1 MiB      0.0 MiB           1       images = {}
   223                                         
   224   1226.1 MiB      0.0 MiB           1       sdata = SpatialData()
   225                                             
   226   1226.1 MiB      0.0 MiB           1       if output_path is not None:
   227   1226.1 MiB      0.0 MiB           1           sdata.path = output_path
   228   1226.1 MiB      0.0 MiB           1           sdata._validate_can_safely_write_to_path(output_path, overwrite=False)
   229   1226.1 MiB      0.0 MiB           1           store = parse_url(output_path, mode="w").store
   230   1226.1 MiB      0.0 MiB           1           _ = zarr.group(store=store, overwrite=False)
   231   1226.1 MiB      0.0 MiB           1           store.close()  
   232                                         
   233                                             # From the public release notes here:
   234                                             # https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/release-notes/release-notes-for-xoa
   235                                             # we see that for distinguishing between the nuclei of polinucleated cells, the `label_id` column is used.
   236                                             # This column is currently not found in the preview data, while I think it is needed in order to unambiguously match
   237                                             # nuclei to cells. Therefore for the moment we only link the table to the cell labels, and not to the nucleus
   238                                             # labels.
   239   1226.1 MiB      0.0 MiB           1       if nucleus_labels:
   240  34151.9 MiB  32925.8 MiB           2           sdata.labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
   241   1226.1 MiB      0.0 MiB           1               path,
   242   1226.1 MiB      0.0 MiB           1               XeniumKeys.CELLS_ZARR,
   243   1226.1 MiB      0.0 MiB           1               specs,
   244   1226.1 MiB      0.0 MiB           1               mask_index=0,
   245   1226.1 MiB      0.0 MiB           1               labels_name="nucleus_labels",
   246   1226.1 MiB      0.0 MiB           1               labels_models_kwargs=labels_models_kwargs,
   247                                                 )
   248  34151.9 MiB      0.0 MiB           1           if output_path is not None:
   249  34152.0 MiB      0.0 MiB           1               print(sdata.labels["nucleus_labels"])
   250  35070.3 MiB    918.3 MiB           2               sdata._write_element(
   251  34152.0 MiB      0.0 MiB           1                   element=sdata.labels["nucleus_labels"],
   252  34152.0 MiB      0.0 MiB           1                   zarr_container_path=output_path,
   253  34152.0 MiB      0.0 MiB           1                   element_type="labels",
   254  34152.0 MiB      0.0 MiB           1                   element_name="nucleus_labels",
   255  34152.0 MiB      0.0 MiB           1                   overwrite=False,
   256                                                     )
   257  35070.3 MiB      0.0 MiB           1               del sdata.labels["nucleus_labels"]
   258   7119.5 MiB -27950.8 MiB           1               gc.collect()
   259   7119.5 MiB      0.0 MiB           1       if cells_labels:
   260  35562.1 MiB  28442.6 MiB           2           sdata.labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
   261   7119.5 MiB      0.0 MiB           1               path,
   262   7119.5 MiB      0.0 MiB           1               XeniumKeys.CELLS_ZARR,
   263   7119.5 MiB      0.0 MiB           1               specs,
   264   7119.5 MiB      0.0 MiB           1               mask_index=1,
   265   7119.5 MiB      0.0 MiB           1               labels_name="cell_labels",
   266   7119.5 MiB      0.0 MiB           1               labels_models_kwargs=labels_models_kwargs,
   267                                                 )
   268  35562.1 MiB      0.0 MiB           1           if output_path is not None:
   269  35583.6 MiB     21.4 MiB           2               sdata._write_element(
   270  35562.1 MiB      0.0 MiB           1                   element=sdata.labels["cell_labels"],
   271  35562.1 MiB      0.0 MiB           1                   zarr_container_path=output_path,
   272  35562.1 MiB      0.0 MiB           1                   element_type="labels",
   273  35562.1 MiB      0.0 MiB           1                   element_name="cell_labels",
   274  35562.1 MiB      0.0 MiB           1                   overwrite=False,
   275                                                     )
   276  35583.6 MiB      0.0 MiB           1               del sdata.labels["cell_labels"]
   277   7632.8 MiB -27950.8 MiB           1               gc.collect()
   278   7632.8 MiB      0.0 MiB           1           if cell_labels_indices_mapping is not None and table is not None:
   279                                                     if not pd.DataFrame.equals(cell_labels_indices_mapping["cell_id"], table.obs[str(XeniumKeys.CELL_ID)]):
   280                                                         warnings.warn(
   281                                                             "The cell_id column in the cell_labels_table does not match the cell_id column derived from the "
   282                                                             "cell labels data. This could be due to trying to read a new version that is not supported yet. "
   283                                                             "Please report this issue.",
   284                                                             UserWarning,
   285                                                             stacklevel=2,
   286                                                         )
   287                                                     else:
   288                                                         table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"]
   289                                                         if not cells_as_circles:
   290                                                             table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"
   291                                         
   292   7632.8 MiB      0.0 MiB           1       if nucleus_boundaries:
   293   8278.5 MiB    645.6 MiB           2           sdata.shapes["nucleus_boundaries"] = _get_polygons(
   294   7632.8 MiB      0.0 MiB           1               path,
   295   7632.8 MiB      0.0 MiB           1               XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
   296   7632.8 MiB      0.0 MiB           1               specs,
   297   7632.8 MiB      0.0 MiB           1               n_jobs,
   298   7632.8 MiB      0.0 MiB           1               idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
   299                                                 )
   300   8278.5 MiB      0.0 MiB           1           if output_path is not None:
   301   8278.5 MiB   -106.4 MiB           2               sdata._write_element(
   302   8278.5 MiB      0.0 MiB           1                   element=sdata.shapes["nucleus_boundaries"],
   303   8278.5 MiB      0.0 MiB           1                   zarr_container_path=output_path,
   304   8278.5 MiB      0.0 MiB           1                   element_type="shapes",
   305   8278.5 MiB      0.0 MiB           1                   element_name="nucleus_boundaries",
   306   8278.5 MiB      0.0 MiB           1                   overwrite=False,
   307                                                     )
   308   8170.5 MiB   -108.0 MiB           1               del sdata.shapes["nucleus_boundaries"]
   309   8158.9 MiB    -11.6 MiB           1               gc.collect()
   310                                         
   311   8158.9 MiB      0.0 MiB           1       if cells_boundaries:
   312   8340.8 MiB    181.9 MiB           2           sdata.shapes["cell_boundaries"] = _get_polygons(
   313   8158.9 MiB      0.0 MiB           1               path,
   314   8158.9 MiB      0.0 MiB           1               XeniumKeys.CELL_BOUNDARIES_FILE,
   315   8158.9 MiB      0.0 MiB           1               specs,
   316   8158.9 MiB      0.0 MiB           1               n_jobs,
   317   8158.9 MiB      0.0 MiB           1               idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
   318                                                 )
   319   8340.8 MiB      0.0 MiB           1           if output_path is not None:
   320   8340.8 MiB   -119.9 MiB           2               sdata._write_element(
   321   8340.8 MiB      0.0 MiB           1                   element=sdata.shapes["cell_boundaries"],
   322   8340.8 MiB      0.0 MiB           1                   zarr_container_path=output_path,
   323   8340.8 MiB      0.0 MiB           1                   element_type="shapes",
   324   8340.8 MiB      0.0 MiB           1                   element_name="cell_boundaries",
   325   8340.8 MiB      0.0 MiB           1                   overwrite=False,
   326                                                     )
   327   8208.1 MiB   -132.7 MiB           1               del sdata.shapes["cell_boundaries"]
   328   8198.2 MiB     -9.8 MiB           1               gc.collect()
   329                                         
   330   8198.2 MiB      0.0 MiB           1       if transcripts:
   331   8482.3 MiB    284.0 MiB           1           sdata.points["transcripts"] = _get_points(path, specs)
   332   8482.3 MiB      0.0 MiB           1           if output_path is not None:
   333  11620.8 MiB   3138.5 MiB           2               sdata._write_element(
   334   8482.3 MiB      0.0 MiB           1                   element=sdata.points["transcripts"],
   335   8482.3 MiB      0.0 MiB           1                   zarr_container_path=output_path,
   336   8482.3 MiB      0.0 MiB           1                   element_type="points",
   337   8482.3 MiB      0.0 MiB           1                   element_name="transcripts",
   338   8482.3 MiB      0.0 MiB           1                   overwrite=False,
   339                                                     )
   340  11620.8 MiB      0.0 MiB           1               del sdata.points["transcripts"]
   341  10868.1 MiB   -752.7 MiB           1               gc.collect()
   342                                         
   343  10868.1 MiB      0.0 MiB           1       if version is None or version < packaging.version.parse("2.0.0"):
   344  10868.1 MiB      0.0 MiB           1           if morphology_mip:
   345  10868.1 MiB   -537.9 MiB           2               sdata.images["morphology_mip"] = _get_images(
   346  10868.1 MiB      0.0 MiB           1                   path,
   347  10868.1 MiB      0.0 MiB           1                   XeniumKeys.MORPHOLOGY_MIP_FILE,
   348  10868.1 MiB      0.0 MiB           1                   imread_kwargs,
   349  10868.1 MiB      0.0 MiB           1                   image_models_kwargs,
   350                                                     )
   351  10330.3 MiB   -537.9 MiB           1               if output_path is not None:
   352  13196.5 MiB   2866.2 MiB           2                   sdata._write_element(
   353  10330.3 MiB      0.0 MiB           1                       element=sdata.images["morphology_mip"],
   354  10330.3 MiB      0.0 MiB           1                       zarr_container_path=output_path,
   355  10330.3 MiB      0.0 MiB           1                       element_type="images",
   356  10330.3 MiB      0.0 MiB           1                       element_name="morphology_mip",
   357  10330.3 MiB      0.0 MiB           1                       overwrite=False,
   358                                                         )
   359  13196.5 MiB      0.0 MiB           1                   del sdata.images["morphology_mip"]
   360  13194.5 MiB     -2.0 MiB           1                   gc.collect()
   361  13194.5 MiB      0.0 MiB           1           if morphology_focus:
   362  13194.5 MiB     -2.0 MiB           2               sdata.images["morphology_focus"] = _get_images(
   363  13194.5 MiB      0.0 MiB           1                   path,
   364  13194.5 MiB      0.0 MiB           1                   XeniumKeys.MORPHOLOGY_FOCUS_FILE,
   365  13194.5 MiB      0.0 MiB           1                   imread_kwargs,
   366  13194.5 MiB      0.0 MiB           1                   image_models_kwargs,
   367                                                     )
   368  13192.6 MiB     -2.0 MiB           1               if output_path is not None:
   369  15717.2 MiB   2524.7 MiB           2                   sdata._write_element(
   370  13192.6 MiB      0.0 MiB           1                       element=sdata.images["morphology_focus"],
   371  13192.6 MiB      0.0 MiB           1                       zarr_container_path=output_path,
   372  13192.6 MiB      0.0 MiB           1                       element_type="images",
   373  13192.6 MiB      0.0 MiB           1                       element_name="morphology_focus",
   374  13192.6 MiB      0.0 MiB           1                       overwrite=False,
   375                                                         )
   376  15717.2 MiB      0.0 MiB           1                   del sdata.images["morphology_focus"]
   377  15716.3 MiB     -1.0 MiB           1                   gc.collect()
   378                                             else:
   379                                                 if morphology_focus:
   380                                                     morphology_focus_dir = path / XeniumKeys.MORPHOLOGY_FOCUS_DIR
   381                                                     files = {f for f in os.listdir(morphology_focus_dir) if f.endswith(".ome.tif")}
   382                                                     if len(files) not in [1, 4]:
   383                                                         raise ValueError(
   384                                                             "Expected 1 (no segmentation kit) or 4 (segmentation kit) files in the morphology focus directory, "
   385                                                             f"found {len(files)}: {files}"
   386                                                         )
   387                                                     if files != {XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(i) for i in range(len(files))}:
   388                                                         raise ValueError(
   389                                                             "Expected files in the morphology focus directory to be named as "
   390                                                             f"{XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(0)} to "
   391                                                             f"{XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.value.format(len(files) - 1)}, found {files}"
   392                                                         )
   393                                                     # the 'dummy' channel is a temporary workaround, see _get_images() for more details
   394                                                     if len(files) == 1:
   395                                                         channel_names = {
   396                                                             0: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0.value,
   397                                                         }
   398                                                     else:
   399                                                         channel_names = {
   400                                                             0: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_0.value,
   401                                                             1: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_1.value,
   402                                                             2: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_2.value,
   403                                                             3: XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_3.value,
   404                                                             4: "dummy",
   405                                                         }
   406                                                     # this reads the scale 0 for all the 1 or 4 channels (the other files are parsed automatically)
   407                                                     # dask.image.imread will call tifffile.imread which will give a warning saying that reading multi-file
   408                                                     # pyramids is not supported; since we are reading the full scale image and reconstructing the pyramid, we
   409                                                     # can ignore this
   410                                         
   411                                                     class IgnoreSpecificMessage(logging.Filter):
   412                                                         def filter(self, record: logging.LogRecord) -> bool:
   413                                                             # Ignore specific log message
   414                                                             if "OME series cannot read multi-file pyramids" in record.getMessage():
   415                                                                 return False
   416                                                             return True
   417                                         
   418                                                     logger = tifffile.logger()
   419                                                     logger.addFilter(IgnoreSpecificMessage())
   420                                                     image_models_kwargs = dict(image_models_kwargs)
   421                                                     assert (
   422                                                         "c_coords" not in image_models_kwargs
   423                                                     ), "The channel names for the morphology focus images are handled internally"
   424                                                     image_models_kwargs["c_coords"] = list(channel_names.values())
   425                                                     sdata.images["morphology_focus"] = _get_images(
   426                                                         morphology_focus_dir,
   427                                                         XeniumKeys.MORPHOLOGY_FOCUS_CHANNEL_IMAGE.format(0),
   428                                                         imread_kwargs,
   429                                                         image_models_kwargs,
   430                                                     )
   431                                                     del image_models_kwargs["c_coords"]
   432                                                     if output_path is not None:
   433                                                         sdata._write_element(
   434                                                             element=sdata.images["morphology_focus"],
   435                                                             zarr_container_path=output_path,
   436                                                             element_type="images",
   437                                                             element_name="morphology_focus",
   438                                                             overwrite=False,
   439                                                         )
   440                                                         del sdata.images["morphology_focus"]
   441                                                         gc.collect()
   442                                                     logger.removeFilter(IgnoreSpecificMessage())
   443                                         
   444  15716.3 MiB      0.0 MiB           1       if table is not None:
   445  15716.3 MiB      0.0 MiB           1           sdata.tables["table"] = table
   446  15716.3 MiB      0.0 MiB           1           if output_path is not None:
   447  15716.3 MiB     -0.6 MiB           2               sdata._write_element(
   448  15716.3 MiB      0.0 MiB           1                   element=sdata.tables["table"],
   449  15716.3 MiB      0.0 MiB           1                   zarr_container_path=output_path,
   450  15716.3 MiB      0.0 MiB           1                   element_type="tables",
   451  15716.3 MiB      0.0 MiB           1                   element_name="table",
   452  15716.3 MiB      0.0 MiB           1                   overwrite=False,
   453                                                     )
   454  15715.7 MiB     -0.6 MiB           1               del sdata.tables["table"]
   455  15715.7 MiB      0.0 MiB           1               gc.collect()
   456                                         
   457  15715.7 MiB      0.0 MiB           1       if cells_as_circles:
   458  15715.7 MiB      0.0 MiB           1           sdata.shapes[specs["region"]] = circles
   459                                         
   460                                             # find and add additional aligned images
   461  15715.7 MiB      0.0 MiB           1       if aligned_images:
   462  15715.7 MiB      0.0 MiB           1           extra_images = _add_aligned_images(path, imread_kwargs, image_models_kwargs)
   463  20413.1 MiB      0.0 MiB           2           for key, value in extra_images.items():
   464  15715.7 MiB      0.0 MiB           1               sdata.images[key] = value
   465  15715.7 MiB      0.0 MiB           1               if output_path is not None:
   466  20413.1 MiB   4697.4 MiB           2                   sdata._write_element(
   467  15715.7 MiB      0.0 MiB           1                       element=sdata.images[key],
   468  15715.7 MiB      0.0 MiB           1                       zarr_container_path=output_path,
   469  15715.7 MiB      0.0 MiB           1                       element_type="images",
   470  15715.7 MiB      0.0 MiB           1                       element_name=key,
   471  15715.7 MiB      0.0 MiB           1                       overwrite=False,
   472                                                         )
   473  20413.1 MiB      0.0 MiB           1                   del sdata.images[key]
   474  20413.1 MiB      0.0 MiB           1                   gc.collect()
   475                                         
   476  20413.1 MiB      0.0 MiB           1       return sdata

As you can see, it reduces peak memory usage by around 44% (from 64 GB to 35 GB).
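The profile suggests why the peak drops: once an element is written and deleted, its memory is released before the next element is built. The following toy model uses only the standard library to show the same effect. The "elements" are plain byte buffers of arbitrary sizes, and `load_element`/`write_element` are hypothetical stand-ins, not spatialdata functions. `tracemalloc` reports the peak of Python-level allocations.

```python
import gc
import tracemalloc

# Arbitrary sizes (MiB) for hypothetical elements; these are NOT spatialdata objects.
ELEMENT_SIZES_MB = {"nucleus_labels": 30, "cell_labels": 28, "transcripts": 3, "morphology_focus": 3}


def load_element(name: str) -> bytearray:
    """Hypothetical stand-in for a reader step that builds one element in memory."""
    return bytearray(ELEMENT_SIZES_MB[name] * 1024 * 1024)


def write_element(name: str, element: bytearray, store: dict) -> None:
    """Hypothetical stand-in for writing to disk: only the size is recorded."""
    store[name] = len(element)


def read_all_then_write(store: dict) -> None:
    # Current behaviour: every element is alive at the same time before writing.
    elements = {name: load_element(name) for name in ELEMENT_SIZES_MB}
    for name, element in elements.items():
        write_element(name, element, store)


def write_as_you_go(store: dict) -> None:
    # Proposed behaviour: write each element right after building it, then drop it.
    for name in ELEMENT_SIZES_MB:
        element = load_element(name)
        write_element(name, element, store)
        del element
        # Mirrors the PR; in CPython the `del` alone already frees the buffer via refcounting.
        gc.collect()


def peak_mb(reader) -> float:
    """Peak traced allocation (MiB) while running one of the strategies above."""
    gc.collect()
    tracemalloc.start()
    reader({})
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return peak / (1024 * 1024)
```

With these sizes, `peak_mb(read_all_then_write)` should report roughly 64 MiB (the sum of all elements) and `peak_mb(write_as_you_go)` roughly 30 MiB (the largest single element), the same pattern as the drops after each `del`/`gc.collect()` pair in the profile.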

I need your opinion on a few things:

  • I followed what is done inside spatialdata.write(). Could we instead call spatialdata.write() repeatedly, every time a new element is generated? In theory, it should write only the new element, right? That way, if the writing code in spatialdata changes, we will not have to update the readers' code.
  • I used the default parameters of the spatialdata write function to avoid adding too many parameters to the reader function. However, should we add a write_kwargs parameter to allow the user to change the writing options?
  • For now, if output_path is set, the reader returns None instead of the SpatialData object. Should we instead return an on-disk SpatialData object by re-reading the file?

Looking forward to feedback!

@LucaMarconato
Member

LucaMarconato commented Jan 5, 2025

Thanks @marcovarrone for reporting this and for implementing a solution, which, judging from the profiling, gives excellent results! I am still triaging some issues/feature requests, so I haven't tried your code yet, but from a quick look I have a few comments. I'll also answer your questions above.

I followed what is done inside spatialdata.write(). Could we instead call spatialdata.write() repeatedly, every time a new element is generated? In theory, it should write only the new element, right? That way, if the writing code in spatialdata changes, we will not have to update the readers' code.

I would change the implementation slightly: remove the calls to internal APIs and instead use only write() (once, to write the empty SpatialData object) and write_element() for each subsequent element. Using only the public API would make the code future-proof.
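A minimal sketch of that loop, using only the two public methods mentioned (`write()` once, then `write_element()` per element). The helper `write_incrementally` and the `producers` triples are hypothetical. It assumes that assigning into and deleting from `sdata.labels`, `sdata.images`, etc. behaves as in the profiled code, and that `write_element(element_name, overwrite=False)` writes an element already held by the object into its backing Zarr store.

```python
from typing import Any, Callable, Iterable, Tuple

# (element_type, element_name, build): element_type is an attribute of the SpatialData
# object such as "labels", "shapes", "points", "images" or "tables"; build() creates the
# element in memory (e.g. a hypothetical wrapper around one of the reader's _get_* helpers).
Producer = Tuple[str, str, Callable[[], Any]]


def write_incrementally(sdata: Any, output_path: str, producers: Iterable[Producer]) -> None:
    """Build, write and release one element at a time using only public methods (sketch).

    Assumes `sdata` exposes `write(file_path, overwrite=...)` and
    `write_element(element_name, overwrite=...)`.
    """
    # Write the still-empty object once, so that it becomes backed by `output_path`.
    sdata.write(output_path, overwrite=False)
    for element_type, name, build in producers:
        container = getattr(sdata, element_type)  # e.g. sdata.labels
        container[name] = build()
        # Persist only this element into the existing store; no private helpers involved.
        sdata.write_element(name, overwrite=False)
        # Drop the in-memory copy before the next (possibly large) element is built.
        del container[name]
```

Keeping each `build` callable lazy is what bounds memory: an element is only constructed after the previous one has been written and released.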

I used the default parameters of the spatialdata write function to avoid adding too many parameters to the reader function. However, should we add a write_kwargs parameter to allow the user to change the writing options?

I think it is a good approach. I would keep the implementation simple and not add an extra parameter: cases in which the user wants to modify the writing options are very rare, and in those cases I think it is acceptable for the user to use extra RAM and call the write function manually.

For now, if output_path is set, the reader returns None instead of the SpatialData object. Should we instead return an on-disk SpatialData object by re-reading the file?

Good point. Some users experienced performance problems because it was not clear to them that they had to write and re-read the SpatialData object after reading it. On one hand, I like the idea of returning the re-read SpatialData object because it keeps the syntax simpler. On the other hand, returning None makes users more aware that they have to re-read the object manually. Overall, I probably prefer returning the re-read SpatialData object. In particular, if passing output_path becomes the recommended usage, this also solves the problem of users forgetting to re-read the object.
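If readers return the re-read object, their return path could look like the sketch below. `reader_result` is a hypothetical helper. `spatialdata.read_zarr` is used as the default re-reader and imported lazily; the `read` parameter exists only so the logic can be exercised without spatialdata installed.

```python
from pathlib import Path
from typing import Any, Callable, Optional, Union

PathLike = Union[str, Path]


def reader_result(
    sdata: Any,
    output_path: Optional[PathLike],
    read: Optional[Callable[[PathLike], Any]] = None,
) -> Any:
    """Return value of a reader under the option preferred above (sketch).

    Without output_path the in-memory object is returned, as today. With output_path the
    elements are already on disk, so the store is re-read and a disk-backed object returned.
    """
    if output_path is None:
        return sdata
    if read is None:
        # Lazy import: spatialdata is only needed when actually re-reading from disk.
        from spatialdata import read_zarr as read
    return read(output_path)
```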
