Skip to content

Commit

Permalink
fully change to use integer indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Jan 9, 2025
1 parent c705649 commit 64c5ffc
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,9 @@ def _inner_join_spatialelement_table(
element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"]
) -> tuple[dict[str, Any], AnnData]:
regions, region_column_name, instance_key = get_table_keys(table)
groups_df = table.obs.groupby(by=region_column_name, observed=False)
obs = table.obs.reset_index()
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
element_indices_mapping = {}
for element_type, name_element in element_dict.items():
for name, element in name_element.items():
if name in regions:
Expand All @@ -400,10 +400,9 @@ def _inner_join_spatialelement_table(

masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows)
element_dict[element_type][name] = masked_element
element_indices_mapping[name] = masked_element.index

joined_indices = _get_joined_table_indices(
joined_indices, element_indices, table_instance_key_column, match_rows
joined_indices, masked_element.index, table_instance_key_column, match_rows
)
else:
warnings.warn(
Expand All @@ -415,17 +414,7 @@ def _inner_join_spatialelement_table(
if joined_indices is not None:
joined_indices = joined_indices.dropna() if any(joined_indices.isna()) else joined_indices

try:
joined_table = table[joined_indices, :].copy() if joined_indices is not None else None
# happens when having duplicate indices in obs. Need to revert to integer indexing.
# TODO: benchmark to check whether this by default is just as quick as obtaining joined_indices.
except pd.errors.InvalidIndexError:
indices = []
obs = table.obs.reset_index()
_, region_col, index_col = get_table_keys(table)
for name_key, index_values in element_indices_mapping.items():
indices.extend(obs[(obs[region_col] == name_key) & (obs[index_col].isin(index_values))].index)
joined_table = table[indices, :].copy()
joined_table = table[joined_indices, :].copy() if joined_indices is not None else None

_inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table)
return element_dict, joined_table
Expand Down Expand Up @@ -468,7 +457,8 @@ def _left_join_spatialelement_table(
if match_rows == "right":
warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2)
regions, region_column_name, instance_key = get_table_keys(table)
groups_df = table.obs.groupby(by=region_column_name, observed=False)
obs = table.obs.reset_index()
groups_df = obs.groupby(by=region_column_name, observed=False)
joined_indices = None
for element_type, name_element in element_dict.items():
for name, element in name_element.items():
Expand Down

0 comments on commit 64c5ffc

Please sign in to comment.