diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ad50ab..9163f9d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,9 @@
+# v0.4.2
+
+- Refactored temporary table creation to use sqlalchemy constructs
+- Ensured temporary tables are created with a primary key if the source table has one
+- Added ID compaction functionality
+
 # v0.4.1
 
 - Added logic to appropriately adjust auto-increment sequence state for postgresql
diff --git a/subsetter.example.yaml b/subsetter.example.yaml
index 3d71fa1..4d7fa24 100644
--- a/subsetter.example.yaml
+++ b/subsetter.example.yaml
@@ -300,3 +300,26 @@ sampler:
   # single-column primary key in another table.
   infer_foreign_keys: none # can be 'none', 'schema', or 'all'
 
+  # Compaction refers to removing gaps in the sampled ID space of a specific
+  # column in a table. In most cases this is unnecessary, but sometimes it can
+  # be helpful to keep the IDs in the sampled dataset small. Enabling compaction
+  # can require more tables to be materialized on the source database and can
+  # have some mild performance impacts on sampling.
+  compact:
+    # If set to true, any table that has a single-column, integral primary key
+    # will have its primary key marked for compaction.
+    primary_keys: false
+
+    # If set to true, any table that has a single-column, integral,
+    # auto-increment primary key will have its primary key marked for
+    # compaction.
+    auto_increment_keys: false
+
+    # Mapping of additional columns that should be compacted if needed. Note that
+    # if multiple columns in the same table are compacted they will end up with
+    # the same value.
+    columns:
+      db1.gizmo: [extra_id]
+
+    # Minimum ID to assign to the first sampled row of each table.
+    start_key: 1
diff --git a/subsetter/config_model.py b/subsetter/config_model.py
index ee16c08..691649c 100644
--- a/subsetter/config_model.py
+++ b/subsetter/config_model.py
@@ -78,11 +78,17 @@ class MultiplicityConfig(ForbidBaseModel):
         extra_columns: Dict[str, List[str]] = {}
         ignore_primary_key_columns: Dict[str, List[str]] = {}
 
+    class CompactConfig(ForbidBaseModel):
+        primary_keys: bool = False
+        auto_increment_keys: bool = False
+        columns: Dict[str, List[str]] = {}
+        start_key: int = 1
+
     output: OutputType = DirectoryOutputConfig(mode="directory", directory="output")
     filters: Dict[str, List[FilterConfig]] = {}  # type: ignore
     multiplicity: MultiplicityConfig = MultiplicityConfig()
     infer_foreign_keys: Literal["none", "schema", "all"] = "none"
-    compact_keys: bool = False
+    compact: CompactConfig = CompactConfig()
 
 
 class SubsetterConfig(ForbidBaseModel):
diff --git a/subsetter/sampler.py b/subsetter/sampler.py
index 8f7b685..9d90c48 100644
--- a/subsetter/sampler.py
+++ b/subsetter/sampler.py
@@ -4,7 +4,6 @@ import logging
 import os
 import re
-import uuid
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
 import sqlalchemy as sa
@@ -37,72 +36,99 @@ def tqdm(x, **_):
 LOGGER = logging.getLogger(__name__)
 SOURCE_BUFFER_SIZE = 1024
 DESTINATION_BUFFER_SIZE = 1024
+SUBSETTER_COMPACT_COLUMN = "_sbsttr_id"
 
 
-def create_temporary_table(
-    conn,
-    schema: str,
-    select: sa.Select,
-    *,
-    primary_key: Tuple[str, ...] = (),
-) -> Tuple[sa.Table, int]:
+class TempTableCreator:
     """
-    Create a temporary table on the passed connection generated by the passed
-    Select object. This method will return a
-
-    Parameters
-        conn: The connection to create the temporary table within. Temporary tables
-              are private to the connection that created them and are cleaned up
-              after the connection is closed.
-        schema: The schema to create the temporary table within. For some dialects
-                temporary tables always exist in their own schema and this parameter
-                will be ignored.
-        primary_key: If set will mark the set of columns passed as primary keys in
-                     the temporary table. This tuple should match a subset of the
-                     column names in the select query.
-
-    Returns a tuple containing the generated table object and the number of rows that
-    were inserted in the table.
+    Help generate temporary tables. Attempts to give the temporary tables
+    friendly names that will not interfere with existing tables. Avoids
+    giving any two tables generated with an instance of this class the same
+    name.
     """
-    dialect = conn.engine.dialect
-
-    # Some dialects can only create temporary tables in an implicit schema
-    temp_schema: Optional[str] = schema
-    if dialect.name in ("postgresql", "sqlite"):
-        temp_schema = None
-
-    temp_name = f"_tmp_subsetter_{str(uuid.uuid4()).replace('-', '_')}"
-
-    # Create the temporary table from the select statement. Mark the requested
-    # columns as part of the primary key.
-    metadata = sa.MetaData()
-    table_obj = sa.Table(
-        temp_name,
-        metadata,
-        schema=temp_schema,
-        prefixes=["TEMPORARY"],
-        *(
-            sa.Column(col.name, col.type, primary_key=col.name in primary_key)
-            for col in select.selected_columns
-        ),
-    )
-    try:
-        metadata.create_all(conn)
-    except Exception as exc:  # pylint: disable=broad-exception-caught
-        # TODO: Is this still needed?
-        #
-        # Some client/server combinations report a read-only error even though the temporary
-        # table creation actually succeeded. We'll just swallow the error here and if there
-        # was a real issue it'll get flagged again when we query against it.
-        if "--read-only" not in str(exc):
-            raise
-
-    # Copy data into the temporary table
-    result = conn.execute(
-        table_obj.insert().from_select(list(table_obj.columns), select)
-    )
+    NAME_PREFIX = "_sbsttr_"
+
+    def __init__(self) -> None:
+        self.name_counts: Dict[str, int] = {}
+
+    def _generate_name(self, name: str) -> str:
+        name = "".join(ch for ch in name if ch.isalnum())[:40]
+        if not name:
+            name = "anon"
+        cnt = self.name_counts.get(name, 0)
+        self.name_counts[name] = cnt + 1
+        if cnt:
+            return f"{self.NAME_PREFIX}{name}_{cnt}"
+        return f"{self.NAME_PREFIX}{name}"
+
+    def create(
+        self,
+        conn,
+        schema: str,
+        select: sa.Select,
+        *,
+        name: str = "",
+        primary_key: Tuple[str, ...] = (),
+    ) -> Tuple[sa.Table, int]:
+        """
+        Create a temporary table on the passed connection generated by the passed
+        Select object. The passed name is used to build a friendly table name.
+
+        Parameters
+            conn: The connection to create the temporary table within. Temporary tables
+                  are private to the connection that created them and are cleaned up
+                  after the connection is closed.
+            schema: The schema to create the temporary table within. For some dialects
+                    temporary tables always exist in their own schema and this parameter
+                    will be ignored.
+            primary_key: If set, will mark the set of columns passed as primary keys in
+                         the temporary table. This tuple should match a subset of the
+                         column names in the select query.
+
+        Returns a tuple containing the generated table object and the number of rows that
+        were inserted in the table.
+        """
+        dialect = conn.engine.dialect
+
+        # Some dialects can only create temporary tables in an implicit schema
+        temp_schema: Optional[str] = schema
+        if dialect.name in ("postgresql", "sqlite"):
+            temp_schema = None
+
+        temp_name = self._generate_name(name)
+
+        # Create the temporary table from the select statement. Mark the requested
+        # columns as part of the primary key.
+        metadata = sa.MetaData()
+        table_obj = sa.Table(
+            temp_name,
+            metadata,
+            schema=temp_schema,
+            prefixes=["TEMPORARY"],
+            *(
+                sa.Column(col.name, col.type, primary_key=col.name in primary_key)
+                for col in select.selected_columns
+            ),
+        )
+        try:
+            metadata.create_all(conn)
+        except Exception as exc:  # pylint: disable=broad-exception-caught
+            # TODO: Is this still needed?
+            #
+            # Some client/server combinations report a read-only error even though the temporary
+            # table creation actually succeeded. We'll just swallow the error here and if there
+            # was a real issue it'll get flagged again when we query against it.
+            if "--read-only" not in str(exc):
+                raise
+
+        # Copy data into the temporary table
+        result = conn.execute(
+            table_obj.insert().from_select(list(table_obj.columns), select)
+        )
-    return table_obj, result.rowcount
+        return table_obj, result.rowcount
 
 
 # pylint: disable=too-many-ancestors,abstract-method
@@ -663,6 +689,8 @@ class Sampler:
     def __init__(self, source: DatabaseConfig, config: SamplerConfig) -> None:
         self.config = config
         self.source_engine = source.database_engine(env_prefix="SUBSET_SOURCE_")
+        self.compact_columns: Dict[Tuple[str, str], Set[str]] = {}
+        self.temp_tables = TempTableCreator()
 
     def sample(
         self,
@@ -688,12 +716,64 @@ def sample(
         output.truncate()
         output.prepare()
 
+        self.compact_columns = self._get_compact_columns(meta)
         with self.source_engine.execution_options().connect() as conn:
             self._materialize_tables(meta, conn, plan)
             self._copy_results(
                 output, conn, meta, plan, insert_order, table_column_multipliers
             )
 
+    def _get_compact_columns(
+        self, meta: DatabaseMetadata
+    ) -> Dict[Tuple[str, str], Set[str]]:
+        """
+        Calculate the set of columns that need to be compacted, keyed by table.
+        """
+        compact_columns = {}
+
+        for table, cols in self.config.compact.columns.items():
+            if not cols:
+                continue
+            table_key = parse_table_name(table)
+            if table_key not in meta.tables:
+                LOGGER.warning(
+                    "Table %s has columns configured for compaction but is not found",
+                    table,
+                )
+            else:
+                compact_columns[table_key] = set(cols)
+
+        if (
+            not self.config.compact.primary_keys
+            and not self.config.compact.auto_increment_keys
+        ):
+            return compact_columns
+
+        for table_key, table_meta in meta.tables.items():
+            if len(table_meta.primary_key) != 1:
+                continue
+
+            col = table_meta.table_obj.columns[table_meta.primary_key[0]]
+            if not issubclass(col.type.python_type, int):  # type: ignore
+                continue
+
+            if (
+                self.config.compact.primary_keys
+                or table_meta.table_obj.autoincrement_column is not None
+            ):
+                compact_columns.setdefault(table_key, set()).add(col.name)
+
+        for table_key, table_meta in meta.tables.items():
+            table_compact_cols = compact_columns.get(table_key, set())
+            for fk in table_meta.foreign_keys:
+                for col_name in fk.columns:
+                    if col_name in table_compact_cols:
+                        raise ValueError(
+                            f"Cannot compact column {table_key[0]}.{table_key[1]}.{col_name} within foreign key"
+                        )
+
+        return compact_columns
+
     def _materialization_order(
         self, meta: DatabaseMetadata, plan: SubsetPlan
     ) -> List[Tuple[str, str, int]]:
@@ -716,9 +796,18 @@ def _record_sampled_tables(
         for table, query in plan.queries.items():
             counter: Dict[Tuple[str, str], int] = {}
             query.build(functools.partial(_record_sampled_tables, counter))
+            dep_graph[parse_table_name(table)] = set(counter.keys())
+
+            # When calculating the max ref count we also need to count the joins
+            # needed at sampling time to facilitate compaction.
+            for fk in meta.tables[parse_table_name(table)].foreign_keys:
+                key = (fk.dst_schema, fk.dst_table)
+                compact_cols = self.compact_columns.get(key, set())
+                if any(col in compact_cols for col in fk.dst_columns):
+                    counter[key] = counter.get(key, 0) + 1
+
             for key, count in counter.items():
                 max_ref_counts[key] = max(max_ref_counts.get(key, 0), count)
-            dep_graph[parse_table_name(table)] = set(counter.keys())
 
         order: List[Tuple[str, str]] = toposort(dep_graph)
         return [
@@ -738,16 +827,33 @@ def _materialize_tables(
             table = meta.tables[(schema, table_name)]
             query = plan.queries[f"{schema}.{table_name}"]
 
+            table_q = query.build(meta.sql_build_context())
+            if (schema, table_name) in self.compact_columns:
+                subq = table_q.subquery()
+                table_q = sa.select(
+                    (
+                        sa.func.row_number(type_=sa.Integer).over(
+                            order_by=[
+                                subq.c[col.name] for col in table.table_obj.primary_key
+                            ],
+                        )
+                        + self.config.compact.start_key
+                        - 1
+                    ).label(SUBSETTER_COMPACT_COLUMN),
+                    subq,
+                )
+
             LOGGER.info(
                 "Materializing sample for %s.%s",
                 schema,
                 table_name,
            )
             meta.temp_tables[(schema, table_name, 0)], rowcount = (
-                create_temporary_table(
+                self.temp_tables.create(
                     conn,
                     schema,
-                    query.build(meta.sql_build_context()),
+                    table_q,
+                    name=table_name,
                    primary_key=table.primary_key,
                )
            )
@@ -765,10 +871,11 @@
             # to work around an issue on mysql with reopening temporary tables.
             for index in range(1, ref_count):
                 meta.temp_tables[(schema, table_name, index)], _ = (
-                    create_temporary_table(
+                    self.temp_tables.create(
                         conn,
                         schema,
                         meta.temp_tables[(schema, table_name, 0)].select(),
+                        name=table_name,
                         primary_key=table.primary_key,
                     )
                 )
@@ -796,10 +903,81 @@ def _copy_results(
 
             LOGGER.info("Sampling %s.%s ...", schema, table_name)
 
+            build_ctx = meta.sql_build_context()
             if (schema, table_name, 0) in meta.temp_tables:
-                query_stmt = sa.select(meta.temp_tables[(schema, table_name, 0)])
+                query_stmt = meta.temp_tables[(schema, table_name, 0)].select()
             else:
-                query_stmt = query.build(meta.sql_build_context())
+                query_stmt = query.build(build_ctx)
+
+            # Figure out what columns are foreign keys to compacted columns.
+            # We must update the value for that column in this table to match
+            # the new compacted column.
+            src_table = meta.tables[(schema, table_name)]
+            remote_compact_cols = {}
+            for fk in src_table.foreign_keys:
+                for src_col, dst_col in zip(fk.columns, fk.dst_columns):
+                    if dst_col in self.compact_columns.get(
+                        (fk.dst_schema, fk.dst_table), set()
+                    ):
+                        remote_compact_cols[src_col] = (
+                            fk.dst_schema,
+                            fk.dst_table,
+                            dst_col,
+                        )
+
+            # Update the query to reflect compaction if needed.
+            compact_cols = self.compact_columns.get((schema, table_name), set())
+            if remote_compact_cols or compact_cols:
+                subq = query_stmt.subquery()
+                id_col = subq.columns.get(
+                    SUBSETTER_COMPACT_COLUMN,
+                    sa.func.row_number(type_=sa.Integer).over(
+                        order_by=[
+                            subq.c[col.name] for col in src_table.table_obj.primary_key
+                        ],
+                    )
+                    + self.config.compact.start_key
+                    - 1,
+                )
+
+                # Determine the final values for each column in the sampled table.
+                # Each column either comes directly from the source table, is locally
+                # compacted, or comes from joining in the compacted ID from another
+                # pre-materialized table.
+ cols: List[sa.ColumnElement] = [] + select_from: sa.FromClause = subq + joined_tables = set() + for col in src_table.table_obj.columns: + if col.name in compact_cols: + cols.append(id_col.label(col.name)) + elif col.name in remote_compact_cols: + dst_schema, dst_table_name, dst_col = remote_compact_cols[ + col.name + ] + dst_table = build_ctx( + SQLTableIdentifier( + table_schema=dst_schema, + table_name=dst_table_name, + sampled=True, + ) + ) + if dst_table in joined_tables: + dst_table = dst_table.alias() + else: + joined_tables.add(dst_table) + cols.append( + dst_table.columns[SUBSETTER_COMPACT_COLUMN].label(col.name) + ) + select_from = sa.join( + select_from, + dst_table, + subq.columns[col.name] == dst_table.columns[dst_col], + isouter=True, + ) + else: + cols.append(subq.columns[col.name]) + + query_stmt = sa.select(*cols).select_from(select_from) LOGGER.debug( " Using statement %s", diff --git a/tests/data/datasets/fk_chain.yaml b/tests/data/datasets/fk_chain.yaml index 5d8cd67..4b1def6 100644 --- a/tests/data/datasets/fk_chain.yaml +++ b/tests/data/datasets/fk_chain.yaml @@ -97,6 +97,7 @@ data: - [4, 4, 6, 1] - [5, 6, 4, 0] - [6, 7, 1, 0] + - [7, 4, null, 1] test.bookmark: - [1, 1, 1] - [2, 2, 2] diff --git a/tests/data/fk_chain.yaml b/tests/data/fk_chain.yaml index 1ddc3b9..021a6c1 100644 --- a/tests/data/fk_chain.yaml +++ b/tests/data/fk_chain.yaml @@ -243,6 +243,10 @@ expected_sample: friend_b: 6 id: 4 sample: 1 + - friend_a: 4 + friend_b: null + id: 7 + sample: 1 test_out.bookmark: - id: 1 user_id: 1 diff --git a/tests/data/fk_chain_compact.yaml b/tests/data/fk_chain_compact.yaml new file mode 100644 index 0000000..6dde8a1 --- /dev/null +++ b/tests/data/fk_chain_compact.yaml @@ -0,0 +1,275 @@ +dataset: fk_chain + +plan_config: + targets: + test.friends: + in: + sample: [1] + select: + - test.* + +sample_config: + compact: + auto_increment_keys: true + start_key: 101 + +expected_plan: + queries: + test.bookmark: + statement: + from: + schema: test + table: bookmark + type: select + where: + columns: + - user_id + negated: false + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: users + type: select + test.friends: + statement: + from: + schema: test + table: friends + type: select + where: + columns: + - sample + type: in + values: + - - 1 + test.referal_owners: + statement: + from: + schema: test + table: referal_owners + type: select + where: + conditions: + - columns: + - referal_id + negated: false + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: referals + type: select + - columns: + - source_website_id + negated: false + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: websites + type: select + type: and + test.referals: + statement: + from: + schema: test + table: referals + type: select + where: + columns: + - website_id + negated: false + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: websites + type: select + test.users: + statement: + from: + schema: test + table: users + type: select + where: + conditions: + - columns: + - id + negated: false + type: in + values: + columns: + - friend_a + from: + sampled: true + schema: test + table: friends + type: select + - columns: + - id + negated: false + type: in + values: + columns: + - friend_b + from: + sampled: true + schema: test + table: friends + type: select + type: or + test.visits: + statement: + from: + schema: test + table: visits + type: select + where: + columns: + 
- user_id + negated: false + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: users + type: select + test.websites: + statement: + from: + schema: test + table: websites + type: select + where: + conditions: + - columns: + - id + negated: false + type: in + values: + columns: + - website_id + from: + sampled: true + schema: test + table: bookmark + type: select + - columns: + - id + negated: false + type: in + values: + columns: + - website_id + from: + sampled: true + schema: test + table: visits + type: select + type: or + +expected_sample: + test_out.bookmark: + - id: 101 + user_id: 101 + website_id: 101 + - id: 102 + user_id: 102 + website_id: 102 + - id: 103 + user_id: 104 + website_id: 103 + test_out.friends: + - friend_a: 101 + friend_b: 102 + id: 101 + sample: 1 + - friend_a: 103 + friend_b: 104 + id: 102 + sample: 1 + - friend_a: 103 + friend_b: 105 + id: 103 + sample: 1 + - friend_a: 103 + friend_b: null + id: 104 + sample: 1 + test_out.referal_owners: + - id: 101 + referal_id: 101 + source_website_id: 101 + - id: 102 + referal_id: 106 + source_website_id: 105 + test_out.referals: + - count: 13 + id: 101 + website_id: 101 + - count: 18 + id: 102 + website_id: 101 + - count: 99 + id: 103 + website_id: 102 + - count: 15 + id: 104 + website_id: 102 + - count: 7 + id: 105 + website_id: 103 + - count: 2 + id: 106 + website_id: 105 + test_out.users: + - id: 101 + name: john + - id: 102 + name: peter + - id: 103 + name: julia + - id: 104 + name: ashley + - id: 105 + name: daniel + test_out.visits: + - count: 10 + id: 101 + user_id: 101 + website_id: 105 + - count: 100 + id: 102 + user_id: 103 + website_id: 104 + - count: 555 + id: 103 + user_id: 104 + website_id: 102 + test_out.websites: + - id: 101 + url: web1 + - id: 102 + url: web2 + - id: 103 + url: web4 + - id: 104 + url: web6 + - id: 105 + url: web9 diff --git a/tests/data/user_orders_compact.yaml b/tests/data/user_orders_compact.yaml new file mode 100644 index 0000000..13de9e0 --- /dev/null +++ b/tests/data/user_orders_compact.yaml @@ -0,0 +1,94 @@ +dataset: user_orders + +plan_config: + targets: + test.users: + in: + sample: [1, 99] + select: + - test.* + +sample_config: + compact: + auto_increment_keys: true + +expected_plan: + queries: + test.order_status: + materialize: false + statement: + from: + schema: test + table: order_status + type: select + where: + columns: + - order_id + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: orders + type: select + test.orders: + materialize: true + statement: + from: + schema: test + table: orders + type: select + where: + columns: + - user_id + type: in + values: + columns: + - id + from: + sampled: true + schema: test + table: users + type: select + test.users: + materialize: true + statement: + type: select + from: + schema: test + table: users + where: + type: in + columns: [sample] + values: [[1], [99]] + +expected_sample: + test_out.users: + - id: 1 + name: john + sample: 1 + - id: 2 + name: richard + sample: 1 + test_out.orders: + - id: 1 + user_id: 1 + name: stuff + - id: 2 + user_id: 2 + name: gold + test_out.order_status: + - id: 1 + order_id: 1 + info: pending + order_square: 1 + - id: 2 + order_id: 1 + info: sent + order_square: 1 + - id: 3 + order_id: 2 + info: lost + order_square: 9 diff --git a/tests/test_live.py b/tests/test_live.py index cb6fa3b..043c304 100644 --- a/tests/test_live.py +++ b/tests/test_live.py @@ -88,6 +88,11 @@ def test_user_orders(db_config): do_dataset_test(db_config, 
"user_orders") +@pytest.mark.parametrize("db_config", DATABASE_CONFIGURATIONS, indirect=True) +def test_user_orders_compact(db_config): + do_dataset_test(db_config, "user_orders_compact") + + @pytest.mark.parametrize("db_config", DATABASE_CONFIGURATIONS, indirect=True) def test_data_types(db_config): do_dataset_test(db_config, "data_types") @@ -98,6 +103,11 @@ def test_fk_chain(db_config): do_dataset_test(db_config, "fk_chain") +@pytest.mark.parametrize("db_config", DATABASE_CONFIGURATIONS, indirect=True) +def test_fk_chain_compact(db_config): + do_dataset_test(db_config, "fk_chain_compact") + + @pytest.mark.parametrize("db_config", DATABASE_CONFIGURATIONS, indirect=True) def test_instruments(db_config): do_dataset_test(db_config, "instruments")