From ef8d1ce2fac0c881ba384e1c1278d27d1d80c1e9 Mon Sep 17 00:00:00 2001 From: Vivian Nguyen Date: Fri, 5 Apr 2024 13:15:58 -0500 Subject: [PATCH] Correct Boolean value writes for enum values --- apis/python/devtools/ingestor | 2 +- apis/python/src/tiledbsoma/_dataframe.py | 14 +++++++++++--- apis/python/tests/test_collection.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/apis/python/devtools/ingestor b/apis/python/devtools/ingestor index 8da4b63560..26ab40b8ee 100755 --- a/apis/python/devtools/ingestor +++ b/apis/python/devtools/ingestor @@ -16,6 +16,7 @@ import os import sys from typing import Optional +import tiledb from somacore import options import tiledbsoma @@ -25,7 +26,6 @@ import tiledbsoma._util import tiledbsoma.io import tiledbsoma.logging from tiledbsoma.options import SOMATileDBContext -import tiledb # ================================================================ diff --git a/apis/python/src/tiledbsoma/_dataframe.py b/apis/python/src/tiledbsoma/_dataframe.py index 9d6b784c77..509c84d4c9 100644 --- a/apis/python/src/tiledbsoma/_dataframe.py +++ b/apis/python/src/tiledbsoma/_dataframe.py @@ -276,7 +276,7 @@ def create( extents, context.native_context, plt_cfg, - (0, timestamp_ms) + (0, timestamp_ms), ) handle = cls._wrapper_type.open(uri, "w", context, tiledb_timestamp) @@ -469,15 +469,23 @@ def write( if not pa.types.is_dictionary(input_field.type): raise ValueError(f"{name} requires dictionary entry") col = values.column(name).combine_chunks() + if pa.types.is_boolean(target_field.type.value_type): + col = col.cast( + pa.dictionary( + target_field.type.index_type, + pa.uint8(), + target_field.type.ordered, + ) + ) new_enmr = self._handle._handle.extend_enumeration(name, col) - + if pa.types.is_binary( target_field.type.value_type ) or pa.types.is_large_binary(target_field.type.value_type): new_enmr = np.array(new_enmr, "S") elif pa.types.is_boolean(target_field.type.value_type): new_enmr = np.array(new_enmr, bool) - + df = pd.Categorical( col.to_pandas(), ordered=target_field.type.ordered, diff --git a/apis/python/tests/test_collection.py b/apis/python/tests/test_collection.py index 75c792aa9e..47f5f5a030 100644 --- a/apis/python/tests/test_collection.py +++ b/apis/python/tests/test_collection.py @@ -519,4 +519,4 @@ def test_context_timestamp(tmp_path: pathlib.Path): assert coll.tiledb_timestamp_ms == 234 sub_1 = coll["sub_1"] assert sub_1.tiledb_timestamp_ms == 234 - assert sub_1["sub_sub"].tiledb_timestamp_ms == 234 \ No newline at end of file + assert sub_1["sub_sub"].tiledb_timestamp_ms == 234