Skip to content

Commit

Permalink
Make create_anonymized_columns work with multi columns transformer (#872
Browse files Browse the repository at this point in the history
)
  • Loading branch information
R-Palazzo authored Aug 23, 2024
1 parent 8cccc81 commit 2fab091
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 1 deletion.
17 changes: 16 additions & 1 deletion rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,8 +871,23 @@ def create_anonymized_columns(self, num_rows, column_names):
'list of valid column names.'
)

columns_to_generate = set()
for column in column_names:
if column not in self._multi_column_fields:
columns_to_generate.add(column)
continue

multi_columns = self._multi_column_fields[column]
if any(col not in column_names for col in multi_columns):
raise InvalidConfigError(
f"Column '{column}' is part of a multi-column field. You must include all "
'columns inside the multi-column field to generate the anonymized columns.'
)

columns_to_generate.add(multi_columns)

transformers = []
for column_name in column_names:
for column_name in sorted(columns_to_generate):
transformer = self.field_transformers.get(column_name)
if not transformer.is_generator():
raise TransformerProcessingError(
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,7 @@ def test_create_anonymized_columns(self):
instance._modified_config = False
instance._subset.return_value = False
instance.random_state = {}
instance._multi_column_fields = {}

random_element = AnonymizedFaker(
function_name='random_element', function_kwargs={'elements': ['a']}
Expand Down Expand Up @@ -1622,6 +1623,7 @@ def test_create_anonymized_columns_invalid_transformers(self):
instance._fitted = True
instance._modified_config = False
instance._subset.return_value = False
instance._multi_column_fields = {}

instance.field_transformers = {
'datetime': FloatFormatter(),
Expand All @@ -1641,6 +1643,116 @@ def test_create_anonymized_columns_invalid_transformers(self):
column_names=['datetime', 'random_element'],
)

def test_create_anonymized_columns_multi_column_transformer(self):
"""Test ``create_anonymized_columns`` with a multi-column transformer."""

class GeneratorTransformer(BaseMultiColumnTransformer):
IS_GENERATOR = True

def __init__(self):
super().__init__()
self.output_properties = {}

def _fit(self, data):
self.columns = list(data.columns)

def _transform(self, data):
return pd.DataFrame()

def _get_prefix(self):
return

def _reverse_transform(self, data):
num_rows = data.shape[0]
for column in self.columns:
data[column] = np.arange(num_rows)

return data

# Setup
instance = HyperTransformer()
instance._multi_column_fields = {
'col1': ('col1', 'col2'),
'col2': ('col1', 'col2'),
}
generator = GeneratorTransformer()
instance.field_transformers = {
('col1', 'col2'): generator,
}
instance.field_sdtypes = {
'col1': 'numerical',
'col2': 'numerical',
}
instance.fit(pd.DataFrame({'col1': [1, 2, 3], 'col2': [1, 2, 3]}))

# Run
output = instance.create_anonymized_columns(num_rows=5, column_names=['col1', 'col2'])

# Assert
expected_output = pd.DataFrame({
'col1': [0, 1, 2, 3, 4],
'col2': [0, 1, 2, 3, 4],
})
pd.testing.assert_frame_equal(output, expected_output, check_dtype=False)

def test_create_anonymized_columns_multi_column_transformer_error(self):
"""Test ``create_anonymized_columns`` raises error with multi-column transformer.
Test that:
- An error occurs when some columns in the column_name list are part of a multi-column
transformer, but not all the required columns of the multi-column
transformer are present.
- An error is raised when a multi-column transformer is not a generator.
"""

class MultiColumnTransformer(BaseMultiColumnTransformer):
IS_GENERATOR = False

def __init__(self):
super().__init__()
self.output_properties = {}

def _fit(self, data):
self.columns = list(data.columns)

def _transform(self, data):
return pd.DataFrame()

def _get_prefix(self):
return

# Setup
instance = HyperTransformer()
instance._multi_column_fields = {
'col1': ('col1', 'col2'),
'col2': ('col1', 'col2'),
}
not_generator = MultiColumnTransformer()
instance.field_transformers = {
('col1', 'col2'): not_generator,
}
instance.field_sdtypes = {
'col1': 'numerical',
'col2': 'numerical',
}
instance.fit(pd.DataFrame({'col1': [1, 2, 3], 'col2': [1, 2, 3]}))

# Run and Assert
error_msg_not_all_multi_column = re.escape(
"Column 'col1' is part of a multi-column field. You must include all "
'columns inside the multi-column field to generate the anonymized columns.'
)
with pytest.raises(InvalidConfigError, match=error_msg_not_all_multi_column):
instance.create_anonymized_columns(num_rows=5, column_names=['col1'])

error_msg_not_generator = re.escape(
"Column '('col1', 'col2')' cannot be anonymized. All columns must be assigned to "
"'AnonymizedFaker', 'RegexGenerator' or other ``generator``. Use "
"'get_config()' to see the current transformer assignments."
)
with pytest.raises(TransformerProcessingError, match=error_msg_not_generator):
instance.create_anonymized_columns(num_rows=5, column_names=['col1', 'col2'])

def test_reverse_transform(self):
"""Test the ``reverse_transform`` method.
Expand Down

0 comments on commit 2fab091

Please sign in to comment.