Skip to content

Commit

Permalink
feat: respect SNOWFLAKE_DEFAULT_CONNECTION_NAME environment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
Zane Clark committed Oct 31, 2024
1 parent 6ee659f commit 6eaeaa9
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 23 deletions.
7 changes: 6 additions & 1 deletion schemachange/config/get_merged_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_yaml_config_kwargs(config_file_path: Optional[Path]) -> dict:

def get_merged_config() -> Union[DeployConfig, RenderConfig]:
env_kwargs: dict[str, str] = get_env_kwargs()
connection_name = env_kwargs.pop("connection_name", None)

cli_kwargs = parse_cli_args(sys.argv[1:])

Expand All @@ -46,7 +47,10 @@ def get_merged_config() -> Union[DeployConfig, RenderConfig]:
connections_file_path = validate_file_path(
file_path=cli_kwargs.pop("connections_file_path", None)
)
connection_name = cli_kwargs.pop("connection_name", None)

if connection_name is None:
connection_name = cli_kwargs.pop("connection_name", None)

config_folder = validate_directory(path=cli_kwargs.pop("config_folder", "."))
config_file_name = cli_kwargs.pop("config_file_name")
config_file_path = Path(config_folder) / config_file_name
Expand All @@ -61,6 +65,7 @@ def get_merged_config() -> Union[DeployConfig, RenderConfig]:
if connections_file_path is None:
connections_file_path = yaml_kwargs.pop("connections_file_path", None)
if config_folder is not None and connections_file_path is not None:
# noinspection PyTypeChecker
connections_file_path = config_folder / connections_file_path

connections_file_path = validate_file_path(file_path=connections_file_path)
Expand Down
1 change: 1 addition & 0 deletions schemachange/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def get_env_kwargs() -> dict[str, str]:
"snowflake_private_key_path": os.getenv("SNOWFLAKE_PRIVATE_KEY_PATH"),
"snowflake_authenticator": os.getenv("SNOWFLAKE_AUTHENTICATOR"),
"snowflake_oauth_token": os.getenv("SNOWFLAKE_TOKEN"),
"connection_name": os.getenv("SNOWFLAKE_DEFAULT_CONNECTION_NAME"),
}
return {k: v for k, v in env_kwargs.items() if v is not None}

Expand Down
14 changes: 14 additions & 0 deletions tests/config/alt-connections.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,17 @@ port = "alt-connections.toml-port"
region = "alt-connections.toml-region"
private-key = "alt-connections.toml-private-key"
token_file_path = "alt-connections.toml-token_file_path"
[anotherconnection]
account = "another-connections.toml-account"
user = "another-connections.toml-user"
role = "another-connections.toml-role"
warehouse = "another-connections.toml-warehouse"
database = "another-connections.toml-database"
schema = "another-connections.toml-schema"
authenticator = "another-connections.toml-authenticator"
password = "another-connections.toml-password"
host = "another-connections.toml-host"
port = "another-connections.toml-port"
region = "another-connections.toml-region"
private-key = "another-connections.toml-private-key"
token_file_path = "another-connections.toml-token_file_path"
30 changes: 8 additions & 22 deletions tests/config/test_get_merged_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@ def get_connection_from_toml(file_path: Path, connection_name: str) -> dict:
"env_kwargs, cli_kwargs, yaml_kwargs, connection_kwargs, expected",
[
pytest.param(
{ # env_kwargs
"snowflake_password": None,
"snowflake_private_key_path": None,
"snowflake_authenticator": None,
},
{}, # env_kwargs
{**default_cli_kwargs}, # cli_kwargs
{}, # yaml_kwargs
{}, # connection_kwargs
Expand All @@ -68,11 +64,7 @@ def get_connection_from_toml(file_path: Path, connection_name: str) -> dict:
id="Deploy: Only required arguments",
),
pytest.param(
{ # env_kwargs
"snowflake_password": None,
"snowflake_private_key_path": None,
"snowflake_authenticator": None,
},
{}, # env_kwargs
{**default_cli_kwargs}, # cli_kwargs
{}, # yaml_kwargs
{ # connection_kwargs
Expand Down Expand Up @@ -106,11 +98,7 @@ def get_connection_from_toml(file_path: Path, connection_name: str) -> dict:
id="Deploy: all connection_kwargs",
),
pytest.param(
{ # env_kwargs
"snowflake_password": None,
"snowflake_private_key_path": None,
"snowflake_authenticator": None,
},
{}, # env_kwargs
{**default_cli_kwargs}, # cli_kwargs
{ # yaml_kwargs
"root_folder": "yaml_root_folder",
Expand Down Expand Up @@ -184,11 +172,7 @@ def get_connection_from_toml(file_path: Path, connection_name: str) -> dict:
id="Deploy: all yaml, all connection_kwargs",
),
pytest.param(
{ # env_kwargs
"snowflake_password": None,
"snowflake_private_key_path": None,
"snowflake_authenticator": None,
},
{}, # env_kwargs
{ # cli_kwargs
**default_cli_kwargs,
"config_folder": "cli_config_folder",
Expand Down Expand Up @@ -293,6 +277,7 @@ def get_connection_from_toml(file_path: Path, connection_name: str) -> dict:
"snowflake_password": "env_snowflake_password",
"snowflake_private_key_path": "env_snowflake_private_key_path",
"snowflake_authenticator": "env_snowflake_authenticator",
"connection_name": "env_connection_name",
},
{ # cli_kwargs
**default_cli_kwargs,
Expand Down Expand Up @@ -383,7 +368,7 @@ def get_connection_from_toml(file_path: Path, connection_name: str) -> dict:
"snowflake_private_key_path": "env_snowflake_private_key_path",
"snowflake_token_path": "cli_snowflake_token_path",
"connections_file_path": Path("cli_connections_file_path"),
"connection_name": "cli_connection_name",
"connection_name": "env_connection_name",
"change_history_table": "cli_change_history_table",
"create_change_history_table": False,
"autocommit": False,
Expand Down Expand Up @@ -707,6 +692,7 @@ def test_invalid_config_folder(mock_parse_cli_args, _):
"SNOWFLAKE_PRIVATE_KEY_PATH": "env_snowflake_private_key_path",
"SNOWFLAKE_AUTHENTICATOR": "env_snowflake_authenticator",
"SNOWFLAKE_TOKEN": "env_snowflake_token",
"SNOWFLAKE_DEFAULT_CONNECTION_NAME": "anotherconnection",
}, # env_kwargs
[ # cli_args
"schemachange",
Expand Down Expand Up @@ -780,7 +766,7 @@ def test_invalid_config_folder(mock_parse_cli_args, _):
"query_tag": "query-tag-from-cli",
"snowflake_oauth_token": "env_snowflake_token",
"oauth_config": {"oauth_config_variable": "cli_oauth_config_value"},
"connection_name": "myaltconnection",
"connection_name": "anotherconnection",
"connections_file_path": assets_path / "alt-connections.toml",
"snowflake_password": "env_snowflake_password",
},
Expand Down
4 changes: 4 additions & 0 deletions tests/config/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def test_get_snowflake_password(env_vars: dict, expected: str):
{"SNOWFLAKE_AUTHENTICATOR": "my_snowflake_authenticator"},
{"snowflake_authenticator": "my_snowflake_authenticator"},
),
(
{"SNOWFLAKE_DEFAULT_CONNECTION_NAME": "my_connection_name"},
{"connection_name": "my_connection_name"},
),
],
)
def test_get_env_kwargs(env_vars: dict, expected: str):
Expand Down

0 comments on commit 6eaeaa9

Please sign in to comment.