From c9246bf7d908d8785e16ffd2bd01e139454202e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20Gr=C3=B6nke?= Date: Tue, 15 Oct 2024 12:43:49 +0200 Subject: [PATCH] ExceptionStatement raises custom SQL errors --- src/backend/mysql/query.rs | 11 +++++++++ src/backend/postgres/query.rs | 6 +++++ src/backend/query_builder.rs | 12 +++++++++ src/backend/sqlite/query.rs | 6 +++++ src/exception.rs | 46 +++++++++++++++++++++++++++++++++++ src/expr.rs | 3 ++- src/lib.rs | 2 ++ tests/mysql/exception.rs | 20 +++++++++++++++ tests/mysql/mod.rs | 1 + tests/postgres/exception.rs | 20 +++++++++++++++ tests/postgres/mod.rs | 1 + tests/sqlite/exception.rs | 21 ++++++++++++++++ tests/sqlite/mod.rs | 1 + 13 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 src/exception.rs create mode 100644 tests/mysql/exception.rs create mode 100644 tests/postgres/exception.rs create mode 100644 tests/sqlite/exception.rs diff --git a/src/backend/mysql/query.rs b/src/backend/mysql/query.rs index 58126014c..3244838db 100644 --- a/src/backend/mysql/query.rs +++ b/src/backend/mysql/query.rs @@ -140,6 +140,17 @@ impl QueryBuilder for MysqlQueryBuilder { fn prepare_returning(&self, _returning: &Option, _sql: &mut dyn SqlWriter) {} + fn prepare_exception_statement(&self, exception: &ExceptionStatement, sql: &mut dyn SqlWriter) { + let mut quoted_exception_message = String::new(); + self.write_string_quoted(&exception.message, &mut quoted_exception_message); + write!( + sql, + "SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = {}", + quoted_exception_message + ) + .unwrap(); + } + fn random_function(&self) -> &str { "RAND" } diff --git a/src/backend/postgres/query.rs b/src/backend/postgres/query.rs index f06501103..bf6662a29 100644 --- a/src/backend/postgres/query.rs +++ b/src/backend/postgres/query.rs @@ -153,6 +153,12 @@ impl QueryBuilder for PostgresQueryBuilder { sql.push_param(value.clone(), self as _); } + fn prepare_exception_statement(&self, exception: &ExceptionStatement, sql: &mut dyn SqlWriter) { + let mut quoted_exception_message = String::new(); + self.write_string_quoted(&exception.message, &mut quoted_exception_message); + write!(sql, "RAISE EXCEPTION {}", quoted_exception_message).unwrap(); + } + fn write_string_quoted(&self, string: &str, buffer: &mut String) { let escaped = self.escape_string(string); let string = if escaped.find('\\').is_some() { diff --git a/src/backend/query_builder.rs b/src/backend/query_builder.rs index 41cac8626..abcdf2ae1 100644 --- a/src/backend/query_builder.rs +++ b/src/backend/query_builder.rs @@ -387,6 +387,9 @@ pub trait QueryBuilder: SimpleExpr::Constant(val) => { self.prepare_constant(val, sql); } + SimpleExpr::Exception(val) => { + self.prepare_exception_statement(val, sql); + } } } @@ -982,6 +985,15 @@ pub trait QueryBuilder: } } + // Translate [`Exception`] into SQL statement. + fn prepare_exception_statement( + &self, + _exception: &ExceptionStatement, + _sql: &mut dyn SqlWriter, + ) { + panic!("Exception handling not implemented for this backend"); + } + /// Convert a SQL value into syntax-specific string fn value_to_string(&self, v: &Value) -> String { self.value_to_string_common(v) diff --git a/src/backend/sqlite/query.rs b/src/backend/sqlite/query.rs index 0a062294d..ceb0b6d43 100644 --- a/src/backend/sqlite/query.rs +++ b/src/backend/sqlite/query.rs @@ -84,6 +84,12 @@ impl QueryBuilder for SqliteQueryBuilder { "MIN" } + fn prepare_exception_statement(&self, exception: &ExceptionStatement, sql: &mut dyn SqlWriter) { + let mut quoted_exception_message = String::new(); + self.write_string_quoted(&exception.message, &mut quoted_exception_message); + write!(sql, "SELECT RAISE(ABORT, {})", quoted_exception_message).unwrap(); + } + fn char_length_function(&self) -> &str { "LENGTH" } diff --git a/src/exception.rs b/src/exception.rs new file mode 100644 index 000000000..2a4af51a2 --- /dev/null +++ b/src/exception.rs @@ -0,0 +1,46 @@ +//! Custom SQL exceptions and errors +use inherent::inherent; + +use crate::backend::SchemaBuilder; + +/// SQL Exceptions +#[derive(Debug, Clone, PartialEq)] +pub struct ExceptionStatement { + pub(crate) message: String, +} + +impl ExceptionStatement { + pub fn new(message: String) -> Self { + Self { message } + } +} + +pub trait ExceptionStatementBuilder { + /// Build corresponding SQL statement for certain database backend and return SQL string + fn build(&self, schema_builder: T) -> String; + + /// Build corresponding SQL statement for certain database backend and return SQL string + fn build_any(&self, schema_builder: &dyn SchemaBuilder) -> String; + + /// Build corresponding SQL statement for certain database backend and return SQL string + fn to_string(&self, schema_builder: T) -> String { + self.build(schema_builder) + } +} + +#[inherent] +impl ExceptionStatementBuilder for ExceptionStatement { + pub fn build(&self, schema_builder: T) -> String { + let mut sql = String::with_capacity(256); + schema_builder.prepare_exception_statement(self, &mut sql); + sql + } + + pub fn build_any(&self, schema_builder: &dyn SchemaBuilder) -> String { + let mut sql = String::with_capacity(256); + schema_builder.prepare_exception_statement(self, &mut sql); + sql + } + + pub fn to_string(&self, schema_builder: T) -> String; +} diff --git a/src/expr.rs b/src/expr.rs index b6894c944..391fca1cc 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -4,7 +4,7 @@ //! //! [`SimpleExpr`] is the expression common among select fields, where clauses and many other places. -use crate::{func::*, query::*, types::*, value::*}; +use crate::{exception::ExceptionStatement, func::*, query::*, types::*, value::*}; /// Helper to build a [`SimpleExpr`]. #[derive(Debug, Clone)] @@ -35,6 +35,7 @@ pub enum SimpleExpr { AsEnum(DynIden, Box), Case(Box), Constant(Value), + Exception(ExceptionStatement), } /// "Operator" methods for building complex expressions. diff --git a/src/lib.rs b/src/lib.rs index 15e1a189c..4f7de5099 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -814,6 +814,7 @@ pub mod backend; pub mod error; +pub mod exception; pub mod expr; pub mod extension; pub mod foreign_key; @@ -832,6 +833,7 @@ pub mod value; pub mod tests_cfg; pub use backend::*; +pub use exception::*; pub use expr::*; pub use foreign_key::*; pub use func::*; diff --git a/tests/mysql/exception.rs b/tests/mysql/exception.rs new file mode 100644 index 000000000..2446330cc --- /dev/null +++ b/tests/mysql/exception.rs @@ -0,0 +1,20 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn signal_sqlstate() { + let message = "Some error occurred"; + assert_eq!( + ExceptionStatement::new(message.to_string()).to_string(MysqlQueryBuilder), + format!("SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = '{message}'") + ); +} + +#[test] +fn escapes_message() { + let unescaped_message = "Does this 'break'?"; + assert_eq!( + ExceptionStatement::new(unescaped_message.to_string()).to_string(MysqlQueryBuilder), + format!("SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Does this \\'break\\'?'") + ); +} diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index d717774f1..e4acd8365 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -1,5 +1,6 @@ use sea_query::{extension::mysql::*, tests_cfg::*, *}; +mod exception; mod foreign_key; mod index; mod query; diff --git a/tests/postgres/exception.rs b/tests/postgres/exception.rs new file mode 100644 index 000000000..372cac0b1 --- /dev/null +++ b/tests/postgres/exception.rs @@ -0,0 +1,20 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn raise_exception() { + let message = "Some error occurred"; + assert_eq!( + ExceptionStatement::new(message.to_string()).to_string(PostgresQueryBuilder), + format!("RAISE EXCEPTION '{message}'") + ); +} + +#[test] +fn escapes_message() { + let unescaped_message = "Does this 'break'?"; + assert_eq!( + ExceptionStatement::new(unescaped_message.to_string()).to_string(PostgresQueryBuilder), + format!("RAISE EXCEPTION E'Does this \\'break\\'?'") + ); +} diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index 82b85df3b..d65872c97 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -1,5 +1,6 @@ use sea_query::{tests_cfg::*, *}; +mod exception; mod foreign_key; mod index; mod query; diff --git a/tests/sqlite/exception.rs b/tests/sqlite/exception.rs new file mode 100644 index 000000000..81d175074 --- /dev/null +++ b/tests/sqlite/exception.rs @@ -0,0 +1,21 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +fn select_raise_abort() { + let message = "Some error occurred here"; + assert_eq!( + ExceptionStatement::new(message.to_string()).to_string(SqliteQueryBuilder), + format!("SELECT RAISE(ABORT, '{}')", message) + ); +} + +#[test] +fn escapes_message() { + let unescaped_message = "Does this 'break'?"; + let escaped_message = "Does this ''break''?"; + assert_eq!( + ExceptionStatement::new(unescaped_message.to_string()).to_string(SqliteQueryBuilder), + format!("SELECT RAISE(ABORT, '{}')", escaped_message) + ); +} diff --git a/tests/sqlite/mod.rs b/tests/sqlite/mod.rs index fc7388cd0..7206e8feb 100644 --- a/tests/sqlite/mod.rs +++ b/tests/sqlite/mod.rs @@ -1,5 +1,6 @@ use sea_query::{tests_cfg::*, *}; +mod exception; mod foreign_key; mod index; mod query;