From 86eaf102e8697cbc8330d8c47fa52449002dfad0 Mon Sep 17 00:00:00 2001 From: jkeifer Date: Thu, 22 Jun 2023 10:20:44 -0700 Subject: [PATCH] add execute-sql cli command --- CHANGELOG.md | 1 + src/dbami/cli.py | 21 +++++++++++++++++++++ tests/test_cli.py | 17 +++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0783a8..0841492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - `list-fixtures` command ([#10]) - `load-fixture` command ([#10]) - `DB.execute_sql()` method for running arbitrary SQL against a database ([#10]) +- `execute-sql` command ([#10]) ### Changed diff --git a/src/dbami/cli.py b/src/dbami/cli.py index 56e35ae..e32db72 100644 --- a/src/dbami/cli.py +++ b/src/dbami/cli.py @@ -538,6 +538,26 @@ async def run() -> int: return syncrun(run()) +class ExecuteSql(DbamiCommand): + help: str = "Run SQL from stdin against the database" + name: str = "execute-sql" + + def set_args( + self, + parser: argparse.ArgumentParser, + ) -> None: + Arguments.project(parser) + Arguments.wait_timeout(parser) + Arguments.database(parser) + + def __call__(self, args: argparse.Namespace) -> int: + async def run() -> int: + await args.db.execute_sql(sys.stdin.read(), database=args.database) + return 0 + + return syncrun(run()) + + class CLI(abc.ABC): def __init__( self, @@ -604,6 +624,7 @@ class DbamiCLI(CLI): Version(), ListFixtures(), LoadFixture(), + ExecuteSql(), ) } diff --git a/tests/test_cli.py b/tests/test_cli.py index 6103713..4590b66 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -476,3 +476,20 @@ def test_load_fixtures_extra(tmp_db, project_dir, extra_fixtures): print(out) print(err) assert rc == 0 + + +def test_execute_sql(tmp_db, project): + stdin = io.StringIO() + stdin.write("create table a_table (id int primary key);") + stdin.seek(0) + rc, out, err = run_cli( + "execute-sql", + "--database", + tmp_db, + stdin=stdin, + ) + print(out) + print(err) + assert rc == 0 + syncrun(project.execute_sql("select * from a_table", database=tmp_db)) + assert True