From 1117af71f2bb0011e0c54fc9170580beeddc8f53 Mon Sep 17 00:00:00 2001
From: AlexandraImbrisca
Date: Tue, 14 Jan 2025 20:45:12 +0100
Subject: [PATCH] Reorder functions and use parametrised testing to check both
 implementations

---
 .../xml_download/utils_write_to_database.py | 190 ++++++++++--------
 .../test_utils_write_to_database.py         |  93 ++++++---
 2 files changed, 170 insertions(+), 113 deletions(-)

diff --git a/open_mastr/xml_download/utils_write_to_database.py b/open_mastr/xml_download/utils_write_to_database.py
index 83251dc5..5176821d 100644
--- a/open_mastr/xml_download/utils_write_to_database.py
+++ b/open_mastr/xml_download/utils_write_to_database.py
@@ -6,7 +6,7 @@
 import numpy as np
 import pandas as pd
 import sqlalchemy
-from sqlalchemy import inspect
+from sqlalchemy import inspect, select
 from sqlalchemy.sql import text
 from sqlalchemy.sql.sqltypes import Date, DateTime
 
@@ -112,6 +112,20 @@ def is_first_file(file_name: str) -> bool:
     )
 
 
+def cast_date_columns_to_datetime(
+    xml_table_name: str, df: pd.DataFrame
+) -> pd.DataFrame:
+    sqlalchemy_columnlist = tablename_mapping[xml_table_name][
+        "__class__"
+    ].__table__.columns.items()
+    for column in sqlalchemy_columnlist:
+        column_name = column[0]
+        if is_date_column(column, df):
+            # Convert column to datetime64, invalid string -> NaT
+            df[column_name] = pd.to_datetime(df[column_name], errors="coerce")
+    return df
+
+
 def cast_date_columns_to_string(xml_table_name: str, df: pd.DataFrame) -> pd.DataFrame:
     column_list = tablename_mapping[xml_table_name][
         "__class__"
@@ -185,6 +199,52 @@ def change_column_names_to_orm_format(
     return df
 
 
+def add_table_to_non_sqlite_database(
+    df: pd.DataFrame,
+    xml_table_name: str,
+    sql_table_name: str,
+    engine: sqlalchemy.engine.Engine,
+) -> None:
+    # get a dictionary for the data types
+    table_columns_list = list(
+        tablename_mapping[xml_table_name]["__class__"].__table__.columns
+    )
+    dtypes_for_writing_sql = {
+        column.name: column.type
+        for column in table_columns_list
+        if column.name in df.columns
+    }
+
+    # Convert date and datetime columns into the datatype datetime.
+    df = cast_date_columns_to_datetime(xml_table_name, df)
+
+    add_missing_columns_to_table(
+        engine, xml_table_name, column_list=df.columns.tolist()
+    )
+
+    for _ in range(10000):
+        try:
+            with engine.connect() as con:
+                with con.begin():
+                    df.to_sql(
+                        sql_table_name,
+                        con=con,
+                        index=False,
+                        if_exists="append",
+                        dtype=dtypes_for_writing_sql,
+                    )
+            break
+
+        except sqlalchemy.exc.DataError as err:
+            delete_wrong_xml_entry(err, df)
+
+        except sqlalchemy.exc.IntegrityError:
+            # error resulting from Unique constraint failed
+            df = write_single_entries_until_not_unique_comes_up(
+                df, xml_table_name, engine
+            )
+
+
 def add_zero_as_first_character_for_too_short_string(df: pd.DataFrame) -> pd.DataFrame:
     """Some columns are read as integer even though they are actually strings starting with a 0. This function converts those columns back to strings and adds a 0 as first character.
@@ -217,6 +277,46 @@ def add_zero_as_first_character_for_too_short_string(df: pd.DataFrame) -> pd.Dat
     return df
 
 
+def write_single_entries_until_not_unique_comes_up(
+    df: pd.DataFrame, xml_table_name: str, engine: sqlalchemy.engine.Engine
+) -> pd.DataFrame:
+    """
+    Remove those rows from the dataframe that already exist in the database table.
+    Parameters
+    ----------
+    df
+    xml_table_name
+    engine
+
+    Returns
+    -------
+    Filtered dataframe
+    """
+
+    table = tablename_mapping[xml_table_name]["__class__"].__table__
+    primary_key = next(c for c in table.columns if c.primary_key)
+
+    with engine.connect() as con:
+        with con.begin():
+            key_list = (
+                pd.read_sql(sql=select(primary_key), con=con).values.squeeze().tolist()
+            )
+
+    len_df_before = len(df)
+    df = df.drop_duplicates(
+        subset=[primary_key.name]
+    )  # drop all entries with duplicated primary keys in the dataframe
+    df = df.set_index(primary_key.name)
+
+    df = df.drop(
+        labels=key_list, errors="ignore"
+    )  # drop primary keys that already exist in the table
+    df = df.reset_index()
+    print(f"{len_df_before - len(df)} entries already existed in the database.")
+
+    return df
+
+
 def add_missing_columns_to_table(
     engine: sqlalchemy.engine.Engine,
     xml_table_name: str,
@@ -356,91 +456,3 @@ def add_table_to_sqlite_database(
             break
         except sqlalchemy.exc.DataError as err:
             delete_wrong_xml_entry(err, df)
-
-
-def write_single_entries_until_not_unique_comes_up(
-    df: pd.DataFrame, xml_table_name: str, engine: sqlalchemy.engine.Engine
-) -> pd.DataFrame:
-    """
-    Remove from dataframe these rows, which are already existing in the database table
-    Parameters
-    ----------
-    df
-    xml_table_name
-    engine
-
-    Returns
-    -------
-    Filtered dataframe
-    """
-
-    table = tablename_mapping[xml_table_name]["__class__"].__table__
-    primary_key = next(c for c in table.columns if c.primary_key)
-
-    with engine.connect() as con:
-        with con.begin():
-            key_list = (
-                pd.read_sql(sql=select(primary_key), con=con).values.squeeze().tolist()
-            )
-
-    len_df_before = len(df)
-    df = df.drop_duplicates(
-        subset=[primary_key.name]
-    )  # drop all entries with duplicated primary keys in the dataframe
-    df = df.set_index(primary_key.name)
-
-    df = df.drop(
-        labels=key_list, errors="ignore"
-    )  # drop primary keys that already exist in the table
-    df = df.reset_index()
-    print(f"{len_df_before - len(df)} entries already existed in the database.")
-
-    return df
-
-
-def add_table_to_non_sqlite_database(
-    df: pd.DataFrame,
-    xml_table_name: str,
-    sql_table_name: str,
-    engine: sqlalchemy.engine.Engine,
-) -> None:
-    def add_table_to_database(
-        df: pd.DataFrame,
-        xml_table_name: str,
-        sql_table_name: str,
-        engine: sqlalchemy.engine.Engine,
-    ) -> None:
-        # get a dictionary for the data types
-        table_columns_list = list(
-            tablename_mapping[xml_table_name]["__class__"].__table__.columns
-        )
-        dtypes_for_writing_sql = {
-            column.name: column.type
-            for column in table_columns_list
-            if column.name in df.columns
-        }
-
-        add_missing_columns_to_table(
-            engine, xml_table_name, column_list=df.columns.tolist()
-        )
-        for _ in range(10000):
-            try:
-                with engine.connect() as con:
-                    with con.begin():
-                        df.to_sql(
-                            sql_table_name,
-                            con=con,
-                            index=False,
-                            if_exists="append",
-                            dtype=dtypes_for_writing_sql,
-                        )
-                break
-
-            except sqlalchemy.exc.DataError as err:
-                delete_wrong_xml_entry(err, df)
-
-            except sqlalchemy.exc.IntegrityError:
-                # error resulting from Unique constraint failed
-                df = write_single_entries_until_not_unique_comes_up(
-                    df, xml_table_name, engine
-                )
diff --git a/tests/xml_download/test_utils_write_to_database.py b/tests/xml_download/test_utils_write_to_database.py
index 2619b9a8..e72e5501 100644
--- a/tests/xml_download/test_utils_write_to_database.py
+++ b/tests/xml_download/test_utils_write_to_database.py
@@ -27,7 +27,8 @@
     is_table_relevant,
     process_table_before_insertion,
     read_xml_file,
-    add_table_to_database,
+    add_table_to_non_sqlite_database,
+    add_table_to_sqlite_database,
 )
 
 # Check if xml file exists
@@ -115,16 +116,27 @@ def test_cast_date_columns_to_string():
     initial_df = pd.DataFrame(
         {
             "EegMastrNummer": [1, 2, 3],
-            "Registrierungsdatum": [datetime(2024, 3, 11).date(), datetime(1999, 2, 1).date(), np.datetime64("nat")],
-            "DatumLetzteAktualisierung": [datetime(2022, 3, 22), datetime(2020, 1, 2, 10, 12, 46),
-                                          np.datetime64("nat")],
+            "Registrierungsdatum": [
+                datetime(2024, 3, 11).date(),
+                datetime(1999, 2, 1).date(),
+                np.datetime64("nat"),
+            ],
+            "DatumLetzteAktualisierung": [
+                datetime(2022, 3, 22),
+                datetime(2020, 1, 2, 10, 12, 46),
+                np.datetime64("nat"),
+            ],
         }
     )
     expected_df = pd.DataFrame(
         {
             "EegMastrNummer": [1, 2, 3],
             "Registrierungsdatum": ["2024-03-11", "1999-02-01", np.nan],
-            "DatumLetzteAktualisierung": ["2022-03-22 00:00:00.000000", "2020-01-02 10:12:46.000000", np.nan],
+            "DatumLetzteAktualisierung": [
+                "2022-03-22 00:00:00.000000",
+                "2020-01-02 10:12:46.000000",
+                np.nan,
+            ],
         }
     )
 
@@ -146,10 +158,14 @@ def test_is_date_column():
     date_column = list(filter(lambda col: col[0] == "Id", columns))[0]
     assert is_date_column(date_column, df) is False
 
-    datetime_column = list(filter(lambda col: col[0] == "DatumLetzteAktualisierung", columns))[0]
+    datetime_column = list(
+        filter(lambda col: col[0] == "DatumLetzteAktualisierung", columns)
+    )[0]
     assert is_date_column(datetime_column, df) is True
 
-    date_column = list(filter(lambda col: col[0] == "WiederinbetriebnahmeDatum", columns))[0]
+    date_column = list(
+        filter(lambda col: col[0] == "WiederinbetriebnahmeDatum", columns)
+    )[0]
     assert is_date_column(date_column, df) is True
 
 
@@ -212,7 +228,9 @@ def test_read_xml_file(zipped_xml_file_path):
     # correctly created, we check that all of its columns are included in our mapping.
     for column in df.columns:
         if column in tablename_mapping["einheitenkernkraft"]["replace_column_names"]:
-            column = tablename_mapping["einheitenkernkraft"]["replace_column_names"][column]
+            column = tablename_mapping["einheitenkernkraft"]["replace_column_names"][
+                column
+            ]
 
         assert column in NuclearExtended.__table__.columns.keys()
 
@@ -246,7 +264,8 @@ def test_change_column_names_to_orm_format():
     )
 
     pd.testing.assert_frame_equal(
-        expected_df, change_column_names_to_orm_format(initial_df, "lokationen"))
+        expected_df, change_column_names_to_orm_format(initial_df, "lokationen")
+    )
 
 
 def test_process_table_before_insertion(zipped_xml_file_path):
@@ -272,8 +291,14 @@
 
     pd.testing.assert_frame_equal(
         expected_df,
-        process_table_before_insertion(initial_df, "einheitenkernkraft", zipped_xml_file_path, bulk_download_date,
-                                       bulk_cleansing=False))
+        process_table_before_insertion(
+            initial_df,
+            "einheitenkernkraft",
+            zipped_xml_file_path,
+            bulk_download_date,
+            bulk_cleansing=False,
+        ),
+    )
 
 
 def test_add_missing_columns_to_table(engine_testdb):
@@ -289,36 +314,49 @@
             "EinheitMastrNummer": ["id1"],
             "DatumLetzteAktualisierung": [datetime(2022, 2, 2)],
         }
     )
-            initial_data_in_db.to_sql('gas_consumer', con=con, if_exists='append', index=False)
+            initial_data_in_db.to_sql(
+                "gas_consumer", con=con, if_exists="append", index=False
+            )
 
-    add_missing_columns_to_table(engine_testdb, 'einheitengasverbraucher', ["NewColumn"])
+    add_missing_columns_to_table(
+        engine_testdb, "einheitengasverbraucher", ["NewColumn"]
+    )
 
     expected_df = pd.DataFrame(
         {
             "EinheitMastrNummer": ["id1"],
             "DatumLetzteAktualisierung": [datetime(2022, 2, 2)],
-            "NewColumn": [None]
+            "NewColumn": [None],
         }
     )
 
     with engine_testdb.connect() as con:
         with con.begin():
-            actual_df = pd.read_sql_table('gas_consumer', con=con)
+            actual_df = pd.read_sql_table("gas_consumer", con=con)
 
     # The actual_df will contain more columns than the expected_df, so we can't use assert_frame_equal.
     assert expected_df.index.isin(actual_df.index).all()
 
 
-def test_add_table_to_database(engine_testdb):
+@pytest.mark.parametrize(
+    "add_table_to_database_function",
+    [add_table_to_sqlite_database, add_table_to_non_sqlite_database],
+)
+def test_add_table_to_database(engine_testdb, add_table_to_database_function):
     with engine_testdb.connect() as con:
         with con.begin():
             # We must recreate the table to be sure that no other data is present.
             con.execute(text("DROP TABLE IF EXISTS gsgk_eeg"))
-    create_database_table(engine_testdb, "anlageneeggeothermiegrubengasdruckentspannung")
+    create_database_table(
+        engine_testdb, "anlageneeggeothermiegrubengasdruckentspannung"
+    )
 
     df = pd.DataFrame(
         {
             "Registrierungsdatum": ["2022-02-02", "2024-03-20"],
             "EegMastrNummer": ["id1", "id2"],
-            "DatumLetzteAktualisierung": ["2022-12-02 10:10:10.000300", "2024-10-10 00:00:00.000000"],
+            "DatumLetzteAktualisierung": [
+                "2022-12-02 10:10:10.000300",
+                "2024-10-10 00:00:00.000000",
+            ],
             "AusschreibungZuschlag": [True, False],
             "Netzbetreiberzuordnungen": ["test1", "test2"],
             "InstallierteLeistung": [1.0, 100.4],
@@ -330,9 +368,12 @@
             "AnlageBetriebsstatus": [None, None],
             "Registrierungsdatum": [datetime(2022, 2, 2), datetime(2024, 3, 20)],
             "EegMastrNummer": ["id1", "id2"],
-            "Meldedatum": [np.datetime64('NaT'), np.datetime64('NaT')],
-            "DatumLetzteAktualisierung": [datetime(2022, 12, 2, 10, 10, 10, 300), datetime(2024, 10, 10)],
-            "EegInbetriebnahmedatum": [np.datetime64('NaT'), np.datetime64('NaT')],
+            "Meldedatum": [np.datetime64("NaT"), np.datetime64("NaT")],
+            "DatumLetzteAktualisierung": [
+                datetime(2022, 12, 2, 10, 10, 10, 300),
+                datetime(2024, 10, 10),
+            ],
+            "EegInbetriebnahmedatum": [np.datetime64("NaT"), np.datetime64("NaT")],
             "VerknuepfteEinheit": [None, None],
             "AnlagenschluesselEeg": [None, None],
             "AusschreibungZuschlag": [True, False],
@@ -340,11 +381,15 @@
             "AnlagenkennzifferAnlagenregister_nv": [None, None],
             "Netzbetreiberzuordnungen": ["test1", "test2"],
             "DatenQuelle": [None, None],
-            "DatumDownload": [np.datetime64('NaT'), np.datetime64('NaT')],
+            "DatumDownload": [np.datetime64("NaT"), np.datetime64("NaT")],
         }
     )
-    add_table_to_database(df, "anlageneeggeothermiegrubengasdruckentspannung", "gsgk_eeg", engine_testdb)
+    add_table_to_database_function(
+        df, "anlageneeggeothermiegrubengasdruckentspannung", "gsgk_eeg", engine_testdb
+    )
 
     with engine_testdb.connect() as con:
         with con.begin():
-            pd.testing.assert_frame_equal(expected_df, pd.read_sql_table('gsgk_eeg', con=con))
+            pd.testing.assert_frame_equal(
+                expected_df, pd.read_sql_table("gsgk_eeg", con=con)
+            )
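
Note on the parametrised test above: pytest calls the test body once per list entry, passing the function object itself as an argument, so a single set of assertions covers both writer implementations. A minimal, self-contained sketch of the mechanism; the two writers below are illustrative stand-ins, not the real add_table_to_sqlite_database / add_table_to_non_sqlite_database:

import pytest


def write_sqlite(records):
    # Stand-in for add_table_to_sqlite_database.
    return list(records)


def write_non_sqlite(records):
    # Stand-in for add_table_to_non_sqlite_database.
    return list(records)


# pytest injects each function object in turn as the `writer` argument.
@pytest.mark.parametrize("writer", [write_sqlite, write_non_sqlite])
def test_both_writers_accept_the_same_input(writer):
    assert writer([1, 2, 3]) == [1, 2, 3]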
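Likewise, the key filtering that write_single_entries_until_not_unique_comes_up performs reduces to three pandas steps. A rough sketch under assumed example data, with an in-memory existing_keys list standing in for the keys fetched via pd.read_sql(select(primary_key), con):

import pandas as pd

# Assumption for illustration: "id1" is already present in the database table.
existing_keys = ["id1"]
df = pd.DataFrame({"EegMastrNummer": ["id1", "id2", "id2"], "value": [1, 2, 2]})

df = df.drop_duplicates(subset=["EegMastrNummer"])   # drop in-frame duplicate keys
df = df.set_index("EegMastrNummer")
df = df.drop(labels=existing_keys, errors="ignore")  # drop keys already in the table
df = df.reset_index()
print(df)  # only the "id2" row remains and can be inserted without conflict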