Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[builder] port to use enums in schema #896

Merged
merged 19 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pyarrow as pa
import tiledbsoma as soma

from .globals import CENSUS_DATASETS_COLUMNS, CENSUS_DATASETS_NAME
from .globals import CENSUS_DATASETS_NAME, CENSUS_DATASETS_TABLE_SPEC

T = TypeVar("T", bound="Dataset")

Expand Down Expand Up @@ -70,14 +70,14 @@ def create_dataset_manifest(info_collection: soma.Collection, datasets: List[Dat
"""
logging.info("Creating dataset_manifest")
manifest_df = Dataset.to_dataframe(datasets)
manifest_df = manifest_df[CENSUS_DATASETS_COLUMNS + ["soma_joinid"]]
manifest_df = manifest_df[list(CENSUS_DATASETS_TABLE_SPEC.field_names())]
if len(manifest_df) == 0:
return

schema = CENSUS_DATASETS_TABLE_SPEC.to_arrow_schema(manifest_df)

# write to a SOMA dataframe
with info_collection.add_new_dataframe(
CENSUS_DATASETS_NAME,
schema=pa.Schema.from_pandas(manifest_df, preserve_index=False),
index_column_names=["soma_joinid"],
CENSUS_DATASETS_NAME, schema=schema, index_column_names=["soma_joinid"]
) as manifest:
manifest.write(pa.Table.from_pandas(manifest_df, preserve_index=False))
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
from .datasets import Dataset
from .globals import (
CENSUS_OBS_PLATFORM_CONFIG,
CENSUS_OBS_TERM_COLUMNS,
CENSUS_OBS_TABLE_SPEC,
CENSUS_VAR_PLATFORM_CONFIG,
CENSUS_VAR_TERM_COLUMNS,
CENSUS_VAR_TABLE_SPEC,
CENSUS_X_LAYERS,
CENSUS_X_LAYERS_PLATFORM_CONFIG,
CXG_OBS_TERM_COLUMNS,
Expand Down Expand Up @@ -161,26 +161,8 @@ def create(self, census_data: soma.Collection) -> None:
# create `ms`
ms = self.experiment.add_new_collection("ms")

# create `obs`
obs_schema = pa.schema(list(CENSUS_OBS_TERM_COLUMNS.items()))
self.experiment.add_new_dataframe(
"obs",
schema=obs_schema,
index_column_names=["soma_joinid"],
platform_config=CENSUS_OBS_PLATFORM_CONFIG,
)

# make measurement and add to ms collection
rna_measurement = ms.add_new_collection(MEASUREMENT_RNA_NAME, soma.Measurement)

# create `var` in the measurement
var_schema = pa.schema(list(CENSUS_VAR_TERM_COLUMNS.items()))
rna_measurement.add_new_dataframe(
"var",
schema=var_schema,
index_column_names=["soma_joinid"],
platform_config=CENSUS_VAR_PLATFORM_CONFIG,
)
ms.add_new_collection(MEASUREMENT_RNA_NAME, soma.Measurement)

def filter_anndata_cells(self, ad: anndata.AnnData) -> Union[None, anndata.AnnData]:
anndata_cell_filter = make_anndata_cell_filter(self.anndata_cell_filter_spec)
Expand All @@ -206,19 +188,19 @@ def accumulate_axes(self, dataset: Dataset, ad: anndata.AnnData) -> int:
add_tissue_mapping(obs_df, dataset.dataset_id)

# add any other computed columns
for key in CENSUS_OBS_TERM_COLUMNS:
for key in CENSUS_OBS_TABLE_SPEC.field_names():
if key not in obs_df:
obs_df[key] = np.full(
(len(obs_df),),
np.nan,
dtype=CENSUS_OBS_TERM_COLUMNS[key].to_pandas_dtype(),
dtype=CENSUS_OBS_TABLE_SPEC.field(key).to_pandas_dtype(ignore_dict_type=True),
)

# Accumulate aggregation counts
self.census_summary_cell_counts = accumulate_summary_counts(self.census_summary_cell_counts, obs_df)

# drop columns we don't want to write (e.g., organism)
obs_df = obs_df[list(CENSUS_OBS_TERM_COLUMNS)]
obs_df = obs_df[list(CENSUS_OBS_TABLE_SPEC.field_names())]

# accumulate obs
self.obs_df_accumulation.append(obs_df)
Expand All @@ -229,9 +211,11 @@ def accumulate_axes(self, dataset: Dataset, ad: anndata.AnnData) -> int:
# NOTE: assumes raw.var is None, OR has same index as var. Currently enforced in open_anndata(),
# but may need to evolve this logic if that assumption is not scalable.
tv = ad.var.rename_axis("feature_id").reset_index()[["feature_id", "feature_name", "feature_length"]]
for key in CENSUS_VAR_TERM_COLUMNS:
for key in CENSUS_VAR_TABLE_SPEC.field_names():
if key not in tv:
tv[key] = np.full((len(tv),), 0, dtype=CENSUS_VAR_TERM_COLUMNS[key].to_pandas_dtype())
tv[key] = np.full(
(len(tv),), 0, dtype=CENSUS_VAR_TABLE_SPEC.field(key).to_pandas_dtype(ignore_dict_type=True)
)
self.var_df = (
pd.concat([self.var_df, tv], ignore_index=True).drop_duplicates() if self.var_df is not None else tv
)
Expand All @@ -255,34 +239,56 @@ def finalize_obs_axes(self) -> None:

def write_obs_dataframe(self) -> None:
logging.info(f"{self.name}: writing obs dataframe")
assert self.experiment is not None
_assert_open_for_write(self.experiment)

if self.obs_df is None or len(self.obs_df) == 0:
obs_df = CENSUS_OBS_TABLE_SPEC.recategoricalize(self.obs_df)
obs_schema = CENSUS_OBS_TABLE_SPEC.to_arrow_schema(obs_df)

# create `obs`
self.experiment.add_new_dataframe(
"obs",
schema=obs_schema,
index_column_names=["soma_joinid"],
platform_config=CENSUS_OBS_PLATFORM_CONFIG,
)

if obs_df is None or obs_df.empty:
logging.info(f"{self.name}: empty obs dataframe")
else:
logging.debug(f"experiment {self.name} obs = {self.obs_df.shape}")
assert not np.isnan(self.obs_df.nnz.to_numpy()).any() # sanity check
logging.debug(f"experiment {self.name} obs = {obs_df.shape}")
assert not np.isnan(obs_df.nnz.to_numpy()).any() # sanity check
pa_table = pa.Table.from_pandas(
self.obs_df,
preserve_index=False,
columns=list(CENSUS_OBS_TERM_COLUMNS),
obs_df, preserve_index=False, columns=list(CENSUS_OBS_TABLE_SPEC.field_names())
)
self.experiment.obs.write(pa_table) # type:ignore
self.experiment.obs.write(pa_table)

def write_var_dataframe(self) -> None:
logging.info(f"{self.name}: writing var dataframe")
assert self.experiment is not None
_assert_open_for_write(self.experiment)

if self.var_df is None or len(self.var_df) == 0:
rna_measurement = self.experiment.ms[MEASUREMENT_RNA_NAME]

var_df = CENSUS_VAR_TABLE_SPEC.recategoricalize(self.var_df)
var_schema = CENSUS_VAR_TABLE_SPEC.to_arrow_schema(var_df)

# create `var` in the measurement
rna_measurement.add_new_dataframe(
"var",
schema=var_schema,
index_column_names=["soma_joinid"],
platform_config=CENSUS_VAR_PLATFORM_CONFIG,
)

if var_df is None or var_df.empty:
logging.info(f"{self.name}: empty var dataframe")
else:
logging.debug(f"experiment {self.name} var = {self.var_df.shape}")
logging.debug(f"experiment {self.name} var = {var_df.shape}")
pa_table = pa.Table.from_pandas(
self.var_df,
preserve_index=False,
columns=list(CENSUS_VAR_TERM_COLUMNS),
var_df, preserve_index=False, columns=list(CENSUS_VAR_TABLE_SPEC.field_names())
)
self.experiment.ms["RNA"].var.write(pa_table) # type:ignore
rna_measurement.var.write(pa_table)

def populate_var_axis(self) -> None:
logging.info(f"{self.name}: populate var axis")
Expand Down Expand Up @@ -346,7 +352,7 @@ def populate_presence_matrix(self, datasets: List[Dataset]) -> None:
assert pm.count_nonzero() == pm.nnz
assert pm.dtype == bool

fdpm = self.experiment.ms["RNA"].add_new_sparse_ndarray( # type:ignore
fdpm = self.experiment.ms[MEASUREMENT_RNA_NAME].add_new_sparse_ndarray( # type:ignore
FEATURE_DATASET_PRESENCE_MATRIX_NAME,
type=pa.bool_(),
shape=(max_dataset_joinid + 1, self.n_var),
Expand Down
Loading
Loading