diff --git a/Cargo.lock b/Cargo.lock index 67e9d12531..2f92a87a54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2414,6 +2414,51 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pest" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" +dependencies = [ + "memchr", + "thiserror", + "ucd-trie", +] + +[[package]] +name = "pest_derive" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d214365f632b123a47fd913301e14c946c61d1c183ee245fa76eb752e59a02dd" +dependencies = [ + "pest", + "pest_generator", +] + +[[package]] +name = "pest_generator" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb55586734301717aea2ac313f50b2eb8f60d2fc3dc01d190eefa2e625f60c4e" +dependencies = [ + "pest", + "pest_meta", + "proc-macro2", + "quote", + "syn 2.0.52", +] + +[[package]] +name = "pest_meta" +version = "2.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b75da2a70cf4d9cb76833c990ac9cd3923c9a8905a8929789ce347c84564d03d" +dependencies = [ + "once_cell", + "pest", + "sha2", +] + [[package]] name = "pin-project" version = "1.1.5" @@ -3647,6 +3692,8 @@ dependencies = [ "memchr", "num-bigint", "once_cell", + "pest", + "pest_derive", "rand", "rust_decimal", "serde", @@ -4085,6 +4132,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "ucd-trie" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" + [[package]] name = "unicode-bidi" version = "0.3.15" diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml index e63d3065f6..637e679dc0 100644 --- a/sqlx-postgres/Cargo.toml +++ b/sqlx-postgres/Cargo.toml @@ -70,6 +70,8 @@ whoami = { version = "1.2.1", default-features = false } serde = { version = "1.0.144", features = ["derive"] } serde_json = { version = "1.0.85", features = ["raw_value"] } +pest = "2.7.14" +pest_derive = "2.7.14" [dependencies.sqlx-core] workspace = true diff --git a/sqlx-postgres/src/lib.rs b/sqlx-postgres/src/lib.rs index c50f53067e..6bab4f5c40 100644 --- a/sqlx-postgres/src/lib.rs +++ b/sqlx-postgres/src/lib.rs @@ -18,6 +18,7 @@ mod message; mod options; mod query_result; mod row; +mod split_sql; mod statement; mod transaction; mod type_checking; diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index c37e92f4d6..7bca218560 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -14,6 +14,7 @@ use crate::executor::Executor; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; +use crate::split_sql::split_sql; use crate::{PgConnectOptions, PgConnection, Postgres}; fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error> { @@ -209,7 +210,6 @@ CREATE TABLE IF NOT EXISTS _sqlx_migrations ( ) -> BoxFuture<'m, Result> { Box::pin(async move { let start = Instant::now(); - // execute migration queries if migration.no_tx { execute_migration(self, migration).await?; @@ -276,10 +276,20 @@ async fn execute_migration( conn: &mut PgConnection, migration: &Migration, ) -> Result<(), MigrateError> { - let _ = conn - .execute(&*migration.sql) - .await - .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + if migration.no_tx { + let statements = split_sql(&*migration.sql); + for sql in statements.iter() { + let _ = conn + .execute(sql.as_str()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + } + } else { + let _ = conn + .execute(&*migration.sql) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + } // language=SQL let _ = query( @@ -301,11 +311,20 @@ async fn revert_migration( conn: &mut PgConnection, migration: &Migration, ) -> Result<(), MigrateError> { - let _ = conn - .execute(&*migration.sql) - .await - .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; - + if migration.no_tx { + let statements = split_sql(&*migration.sql); + for sql in statements.iter() { + let _ = conn + .execute(sql.as_str()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + } + } else { + let _ = conn + .execute(&*migration.sql) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + } // language=SQL let _ = query(r#"DELETE FROM _sqlx_migrations WHERE version = $1"#) .bind(migration.version) diff --git a/sqlx-postgres/src/split_sql.rs b/sqlx-postgres/src/split_sql.rs new file mode 100644 index 0000000000..62b36c6915 --- /dev/null +++ b/sqlx-postgres/src/split_sql.rs @@ -0,0 +1,130 @@ +use pest::Parser; +use pest_derive::Parser; + +#[derive(Parser)] +#[grammar_inline = r#" +// The top-level rule matches the entire SQL input +sql = { SOI ~ statement* ~ EOI } + +// A statement consists of optional leading comments and whitespace, content, and is terminated by a semicolon or end of input +statement = { (WHITESPACE | COMMENT)* ~ statement_content ~ (semicolon | &EOI) } + +// Statement content is a sequence of constructs, comments, whitespace, or non-construct characters +statement_content = { (construct | COMMENT | WHITESPACE | non_construct_char)+ } + +// Constructs that may contain semicolons internally +construct = { DOLLAR_QUOTED_STRING | SINGLE_QUOTED_STRING | DOUBLE_QUOTED_IDENTIFIER } + +// Non-construct characters are any characters except semicolons +non_construct_char = { !semicolon ~ ANY } + +// Semicolon outside constructs acts as a statement terminator +semicolon = { ";" } + +// Single-quoted string literals, handling escaped quotes +SINGLE_QUOTED_STRING = { "'" ~ SINGLE_QUOTED_CONTENT ~ ("'" | EOI) } +SINGLE_QUOTED_CONTENT = { ( "''" | !("'" | EOI) ~ ANY )* } + +// Double-quoted identifiers, handling escaped quotes +DOUBLE_QUOTED_IDENTIFIER = { "\"" ~ DOUBLE_QUOTED_IDENTIFIER_CONTENT ~ ("\"" | EOI) } +DOUBLE_QUOTED_IDENTIFIER_CONTENT = { ( "\"\"" | !("\"" | EOI) ~ ANY )* } + +// Dollar-quoted strings, handling custom tags +DOLLAR_QUOTED_STRING = { DOLLAR_QUOTE_START ~ DOLLAR_QUOTED_CONTENT ~ DOLLAR_QUOTE_END } +DOLLAR_QUOTE_START = { "$" ~ DOLLAR_QUOTE_TAG ~ "$" } +DOLLAR_QUOTE_TAG = { ASCII_ALPHANUMERIC* } +DOLLAR_QUOTED_CONTENT = { ( !DOLLAR_QUOTE_END ~ ANY )* } +DOLLAR_QUOTE_END = { "$" ~ DOLLAR_QUOTE_TAG ~ "$" } + +// Comments (single-line and multi-line) +COMMENT = { SINGLE_LINE_COMMENT | MULTI_LINE_COMMENT } +SINGLE_LINE_COMMENT = { "--" ~ (!NEWLINE ~ ANY)* ~ NEWLINE? } + +MULTI_LINE_COMMENT = { "/*" ~ MULTI_LINE_COMMENT_CONTENT* ~ ( "*/" | EOI ) } +MULTI_LINE_COMMENT_CONTENT = { MULTI_LINE_COMMENT | (!"/*" ~ !"*/" ~ ANY) } + +// Whitespace rules +WHITESPACE = { " " | "\t" | NEWLINE } +NEWLINE = { "\r\n" | "\n" | "\r" } +"#] +struct PsqlSpliter; + +/// Splits a PostgreSQL query string into it's individual statements. +/// +/// This function parses and splits a SQL input string into separate statements, handling +/// PostgreSQL-specific syntax elements such as: +/// +/// - **Dollar-quoted strings**: Supports custom dollar-quoted tags (e.g., `$$`, `$tag$`). +/// - **Single and double-quoted strings**: Handles escaped quotes inside strings. +/// - **Comments**: Supports single-line (`--`) and multi-line (`/* ... */`) comments, preserving them as part of the statement. +/// - **Whitespace**: Retains all leading and trailing whitespace and comments around each statement. +/// - **Semicolons**: Recognizes semicolons as statement terminators, while ignoring them inside strings or comments. +/// +/// If parsing fails or only one statement is found, the input is returned in full. +/// +/// ```no_run +/// use sql_split_pest::split_psql; +/// let sql = r#" +/// -- First query +/// INSERT INTO users (name) VALUES ('Alice; Bob'); +/// +/// -- Second query +/// SELECT * FROM posts; +/// +/// /* Multi-line +/// comment */ +/// CREATE FUNCTION test_function() +/// RETURNS VOID AS $$ +/// BEGIN +/// -- Multiple statements inside the function +/// INSERT INTO table_a VALUES (1); +/// INSERT INTO table_b VALUES (2); +/// END; +/// $$ LANGUAGE plpgsql; +/// +/// -- invalid sql +/// SELECT 'This is an unterminated string FROM users; +/// SELECT * FROM users WHERE name = AND email = 'john@example.com'; +/// SELECT * FROM users JOIN other_table ON; +/// +/// "#; +/// +/// let statements = split_psql(sql); +/// dbg!(&statements); +/// assert_eq!(statements.len(), 4); +/// assert!(statements[0].contains("INSERT INTO users")); +/// assert!(statements[1].contains("SELECT * FROM posts")); +/// assert!(statements[2].contains("CREATE FUNCTION")); +/// assert!(statements[2].contains("plpgsql")); +/// assert!(statements[3].contains("other_table")); +/// ``` +pub fn split_sql>(sql: S) -> Vec { + let sql_str = sql.as_ref(); + + PsqlSpliter::parse(Rule::sql, sql_str).map_or_else( + |_| vec![sql_str.to_string()], + |mut parsed| match parsed.next() { + // this should never happen + None => vec![sql_str.to_string()], + Some(sql) => { + let mut statements = Vec::new(); + let mut statement = String::new(); + for pair in sql.into_inner() { + match pair.as_rule() { + Rule::WHITESPACE | Rule::COMMENT => statement.push_str(pair.as_str()), + Rule::statement | Rule::EOI => { + statement.push_str(pair.as_str()); + // omit empty whitespace at the end of sql + if !statement.is_empty() && !statement.chars().all(char::is_whitespace) + { + statements.push(std::mem::take(&mut statement)); + } + } + _ => unreachable!(), + } + } + statements + } + }, + ) +} diff --git a/tests/postgres/migrate.rs b/tests/postgres/migrate.rs index 636dffe860..bd0927edc9 100644 --- a/tests/postgres/migrate.rs +++ b/tests/postgres/migrate.rs @@ -74,7 +74,7 @@ async fn no_tx(mut conn: PoolConnection) -> anyhow::Result<()> { // run migration migrator.run(&mut conn).await?; - // check outcome + // check outcomes let res: String = conn .fetch_one("SELECT datname FROM pg_database WHERE datname = 'test_db'") .await? @@ -82,6 +82,13 @@ async fn no_tx(mut conn: PoolConnection) -> anyhow::Result<()> { assert_eq!(res, "test_db"); + let res: String = conn + .fetch_one("SELECT email FROM users WHERE username = 'test_user'") + .await? + .get(0); + + assert_eq!(res, "test_user@example.com"); + Ok(()) } diff --git a/tests/postgres/migrations_no_tx/0_create_db.sql b/tests/postgres/migrations_no_tx/0_create_db.sql index 95451f4adf..7f9e224aa1 100644 --- a/tests/postgres/migrations_no_tx/0_create_db.sql +++ b/tests/postgres/migrations_no_tx/0_create_db.sql @@ -1,3 +1,14 @@ -- no-transaction CREATE DATABASE test_db; + +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(50) NOT NULL, + email VARCHAR(100) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX CONCURRENTLY idx_users_email ON users(email); + +INSERT INTO users (username, email) VALUES ('test_user', 'test_user@example.com'); \ No newline at end of file