diff --git a/schemachange/config/BaseConfig.py b/schemachange/config/BaseConfig.py index 0051d40..5f9b192 100644 --- a/schemachange/config/BaseConfig.py +++ b/schemachange/config/BaseConfig.py @@ -4,7 +4,7 @@ import logging from abc import ABC from pathlib import Path -from typing import Literal, ClassVar, TypeVar +from typing import Literal, TypeVar import structlog @@ -38,10 +38,13 @@ def factory( modules_folder: Path | str | None = None, config_vars: str | dict | None = None, log_level: int = logging.INFO, + connection_secrets: set[str] | None = None, **kwargs, ): try: secrets = get_config_secrets(config_vars) + if connection_secrets is not None: + secrets.update(connection_secrets) except Exception as e: raise Exception( "config_vars did not parse correctly, please check its configuration" diff --git a/schemachange/config/DeployConfig.py b/schemachange/config/DeployConfig.py index 1558b94..3cffeaf 100644 --- a/schemachange/config/DeployConfig.py +++ b/schemachange/config/DeployConfig.py @@ -89,10 +89,20 @@ def factory( table_str=change_history_table ) + connection_secrets = { + secret + for secret in [ + kwargs.get("snowflake_password"), + kwargs.get("snowflake_oauth_token"), + ] + if secret is not None + } + return super().factory( subcommand="deploy", config_file_path=config_file_path, change_history_table=change_history_table, + connection_secrets=connection_secrets, **kwargs, ) diff --git a/schemachange/session/SnowflakeSession.py b/schemachange/session/SnowflakeSession.py index 6078f9f..1c3bd96 100644 --- a/schemachange/session/SnowflakeSession.py +++ b/schemachange/session/SnowflakeSession.py @@ -59,24 +59,24 @@ def __init__( if query_tag: self.session_parameters["QUERY_TAG"] += f";{query_tag}" - self.con = snowflake.connector.connect( - account=self.account, - user=self.user, - database=kwargs.get("database"), - schema=kwargs.get("schema"), - role=self.role, - warehouse=self.warehouse, - private_key=kwargs.get("private_key"), - private_key_file=kwargs.get("private_key_path"), - private_key_file_pwd=kwargs.get("private_key_path_password"), - token=kwargs.get("oauth_token"), - password=kwargs.get("password"), - authenticator=kwargs.get("authenticator"), - connection_name=kwargs.get("connection_name"), - connections_file_path=kwargs.get("connections_file_path"), - application=application, - session_parameters=self.session_parameters, - ) + connect_kwargs = { + "account": self.account, + "user": self.user, + "database": kwargs.get("database"), + "schema": kwargs.get("schema"), + "role": self.role, + "warehouse": self.warehouse, + "private_key_file": kwargs.get("private_key_path"), + "token": kwargs.get("oauth_token"), + "password": kwargs.get("password"), + "authenticator": kwargs.get("authenticator"), + "connection_name": kwargs.get("connection_name"), + "connections_file_path": kwargs.get("connections_file_path"), + "application": application, + "session_parameters": self.session_parameters, + } + self.logger.info("snowflake.connector.connect kwargs", **connect_kwargs) + self.con = snowflake.connector.connect(**connect_kwargs) print(f"Current session ID: {self.con.session_id}") if not self.autocommit: diff --git a/tests/session/test_SnowflakeSession.py b/tests/session/test_SnowflakeSession.py index 96b39ac..32b9d4b 100644 --- a/tests/session/test_SnowflakeSession.py +++ b/tests/session/test_SnowflakeSession.py @@ -34,7 +34,7 @@ def test_fetch_change_history_metadata_exists(self, session: SnowflakeSession): result = session.fetch_change_history_metadata() assert result == {"created": "created", "last_altered": "last_altered"} assert session.con.execute_string.call_count == 1 - assert session.logger.calls[0][1][0] == "Executing query" + assert session.logger.calls[1][1][0] == "Executing query" def test_fetch_change_history_metadata_does_not_exist( self, session: SnowflakeSession @@ -43,4 +43,4 @@ def test_fetch_change_history_metadata_does_not_exist( result = session.fetch_change_history_metadata() assert result == {} assert session.con.execute_string.call_count == 1 - assert session.logger.calls[0][1][0] == "Executing query" + assert session.logger.calls[1][1][0] == "Executing query"