From 9b9c217b3d2fe309e3aff1b939ae90f647eedbdb Mon Sep 17 00:00:00 2001 From: Marco Napetti <7566389+nappa85@users.noreply.github.com> Date: Sat, 2 Oct 2021 20:13:52 +0200 Subject: [PATCH] Transaction 2 (#199) * Stream implementation * use ouroboros to cover self-references * Reduce test size * Reduce test size * Complete transaction rewrite + streams * Less radioactive unsafe * Mutex transaction * Solve clippy lints * panic on drop of a locked transaction * panic on drop of a locked transaction * panic on drop of a locked transaction * Centralize Sync requirement --- Cargo.toml | 1 + src/database/connection.rs | 51 ++++- src/database/db_connection.rs | 28 ++- src/database/db_transaction.rs | 298 ++++++++++++++++------------- src/database/mock.rs | 4 +- src/database/mod.rs | 2 + src/database/stream/mod.rs | 5 + src/database/stream/query.rs | 108 +++++++++++ src/database/stream/transaction.rs | 82 ++++++++ src/driver/mock.rs | 21 +- src/driver/sqlx_mysql.rs | 38 +++- src/driver/sqlx_postgres.rs | 38 +++- src/driver/sqlx_sqlite.rs | 38 +++- src/entity/active_model.rs | 24 ++- src/executor/delete.rs | 20 +- src/executor/insert.rs | 28 +-- src/executor/paginator.rs | 8 +- src/executor/select.rs | 109 +++++++---- src/executor/update.rs | 34 ++-- tests/stream_tests.rs | 37 ++++ tests/transaction_tests.rs | 2 +- 21 files changed, 702 insertions(+), 274 deletions(-) create mode 100644 src/database/stream/mod.rs create mode 100644 src/database/stream/query.rs create mode 100644 src/database/stream/transaction.rs create mode 100644 tests/stream_tests.rs diff --git a/Cargo.toml b/Cargo.toml index 3e5a9682d..eeba39c9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ serde = { version = "^1.0", features = ["derive"] } serde_json = { version = "^1", optional = true } sqlx = { version = "^0.5", optional = true } uuid = { version = "0.8", features = ["serde", "v4"], optional = true } +ouroboros = "0.11" [dev-dependencies] smol = { version = "^1.2" } diff --git a/src/database/connection.rs b/src/database/connection.rs index d2a6f9018..b1ba1e7cb 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,4 +1,4 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin, sync::Arc}; use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*}; use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder}; @@ -11,7 +11,7 @@ pub enum DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection), #[cfg(feature = "mock")] - MockDatabaseConnection(crate::MockDatabaseConnection), + MockDatabaseConnection(Arc), Disconnected, } @@ -53,7 +53,9 @@ impl std::fmt::Debug for DatabaseConnection { } #[async_trait::async_trait] -impl ConnectionTrait for DatabaseConnection { +impl<'a> ConnectionTrait<'a> for DatabaseConnection { + type Stream = crate::QueryStream; + fn get_database_backend(&self) -> DbBackend { match self { #[cfg(feature = "sqlx-mysql")] @@ -77,7 +79,7 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } @@ -91,7 +93,7 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } @@ -105,16 +107,46 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await, + DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt), DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())), } } + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)), + DatabaseConnection::Disconnected => panic!("Disconnected"), + }) + }) + } + + async fn begin(&self) -> Result { + match self { + #[cfg(feature = "sqlx-mysql")] + DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-postgres")] + DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await, + #[cfg(feature = "sqlx-sqlite")] + DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await, + #[cfg(feature = "mock")] + DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await, + DatabaseConnection::Disconnected => panic!("Disconnected"), + } + } + /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, T: Send, E: std::error::Error + Send, { @@ -126,7 +158,10 @@ impl ConnectionTrait for DatabaseConnection { #[cfg(feature = "sqlx-sqlite")] DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await, #[cfg(feature = "mock")] - DatabaseConnection::MockDatabaseConnection(_) => unimplemented!(), //TODO: support transaction in mock connection + DatabaseConnection::MockDatabaseConnection(conn) => { + let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await + }, DatabaseConnection::Disconnected => panic!("Disconnected"), } } diff --git a/src/database/db_connection.rs b/src/database/db_connection.rs index 615c0c3fd..569f38965 100644 --- a/src/database/db_connection.rs +++ b/src/database/db_connection.rs @@ -1,8 +1,24 @@ -use std::{pin::Pin, future::Future}; -use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError}; +use std::{future::Future, pin::Pin, sync::Arc}; +use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError}; +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use sqlx::pool::PoolConnection; + +pub(crate) enum InnerConnection { + #[cfg(feature = "sqlx-mysql")] + MySql(PoolConnection), + #[cfg(feature = "sqlx-postgres")] + Postgres(PoolConnection), + #[cfg(feature = "sqlx-sqlite")] + Sqlite(PoolConnection), + #[cfg(feature = "mock")] + Mock(Arc), +} #[async_trait::async_trait] -pub trait ConnectionTrait: Sync { +pub trait ConnectionTrait<'a>: Sync { + type Stream: Stream>; + fn get_database_backend(&self) -> DbBackend; async fn execute(&self, stmt: Statement) -> Result; @@ -11,11 +27,15 @@ pub trait ConnectionTrait: Sync { async fn query_all(&self, stmt: Statement) -> Result, DbErr>; + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>>; + + async fn begin(&self) -> Result; + /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, T: Send, E: std::error::Error + Send; diff --git a/src/database/db_transaction.rs b/src/database/db_transaction.rs index 000f6f609..b403971eb 100644 --- a/src/database/db_transaction.rs +++ b/src/database/db_transaction.rs @@ -1,153 +1,201 @@ -use std::{pin::Pin, future::Future}; -use crate::{DbBackend, ConnectionTrait, DbErr, ExecResult, QueryResult, Statement, debug_print}; +use std::{sync::Arc, future::Future, pin::Pin}; +use crate::{ConnectionTrait, DbBackend, DbErr, ExecResult, InnerConnection, QueryResult, Statement, TransactionStream, debug_print}; +use futures::lock::Mutex; #[cfg(feature = "sqlx-dep")] use crate::{sqlx_error_to_exec_err, sqlx_error_to_query_err}; #[cfg(feature = "sqlx-dep")] -use sqlx::Connection; +use sqlx::{pool::PoolConnection, TransactionManager}; -#[cfg(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite"))] -use futures::lock::Mutex; +// a Transaction is just a sugar for a connection where START TRANSACTION has been executed +pub struct DatabaseTransaction { + conn: Arc>, + backend: DbBackend, + open: bool, +} -#[derive(Debug)] -pub enum DatabaseTransaction<'a> { +impl std::fmt::Debug for DatabaseTransaction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DatabaseTransaction") + } +} + +impl DatabaseTransaction { #[cfg(feature = "sqlx-mysql")] - SqlxMySqlTransaction(Mutex>), + pub(crate) async fn new_mysql(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::MySql(inner))), DbBackend::MySql).await + } + #[cfg(feature = "sqlx-postgres")] - SqlxPostgresTransaction(Mutex>), - #[cfg(feature = "sqlx-sqlite")] - SqlxSqliteTransaction(Mutex>), - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - None(&'a ()), -} + pub(crate) async fn new_postgres(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Postgres(inner))), DbBackend::Postgres).await + } -#[cfg(feature = "sqlx-mysql")] -impl<'a> From> for DatabaseTransaction<'a> { - fn from(inner: sqlx::Transaction<'a, sqlx::MySql>) -> Self { - DatabaseTransaction::SqlxMySqlTransaction(Mutex::new(inner)) + #[cfg(feature = "sqlx-sqlite")] + pub(crate) async fn new_sqlite(inner: PoolConnection) -> Result { + Self::build(Arc::new(Mutex::new(InnerConnection::Sqlite(inner))), DbBackend::Sqlite).await } -} -#[cfg(feature = "sqlx-postgres")] -impl<'a> From> for DatabaseTransaction<'a> { - fn from(inner: sqlx::Transaction<'a, sqlx::Postgres>) -> Self { - DatabaseTransaction::SqlxPostgresTransaction(Mutex::new(inner)) + #[cfg(feature = "mock")] + pub(crate) async fn new_mock(inner: Arc) -> Result { + let backend = inner.get_database_backend(); + Self::build(Arc::new(Mutex::new(InnerConnection::Mock(inner))), backend).await } -} -#[cfg(feature = "sqlx-sqlite")] -impl<'a> From> for DatabaseTransaction<'a> { - fn from(inner: sqlx::Transaction<'a, sqlx::Sqlite>) -> Self { - DatabaseTransaction::SqlxSqliteTransaction(Mutex::new(inner)) + async fn build(conn: Arc>, backend: DbBackend) -> Result { + let res = DatabaseTransaction { + conn, + backend, + open: true, + }; + match *res.conn.lock().await { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::begin(c).await.map_err(sqlx_error_to_query_err)? + }, + // should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(res) } -} -#[allow(dead_code)] -impl<'a> DatabaseTransaction<'a> { pub(crate) async fn run(self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'a>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, T: Send, E: std::error::Error + Send, { let res = callback(&self).await.map_err(|e| TransactionError::Transaction(e)); if res.is_ok() { - self.commit().await?; + self.commit().await.map_err(|e| TransactionError::Connection(e))?; } else { - self.rollback().await?; + self.rollback().await.map_err(|e| TransactionError::Connection(e))?; } res } - async fn commit(self) -> Result<(), TransactionError> - where E: std::error::Error { - match self { + pub async fn commit(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.commit().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.commit().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.commit().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::commit(c).await.map_err(sqlx_error_to_query_err)? }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, } + Ok(()) } - async fn rollback(self) -> Result<(), TransactionError> - where E: std::error::Error { - match self { + pub async fn rollback(mut self) -> Result<(), DbErr> { + self.open = false; + match *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.rollback().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::MySql(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.rollback().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::Postgres(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(inner) => { - let transaction = inner.into_inner(); - transaction.rollback().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string()))) + InnerConnection::Sqlite(ref mut c) => { + ::TransactionManager::rollback(c).await.map_err(sqlx_error_to_query_err)? }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + Ok(()) + } + + // the rollback is queued and will be performed on next async operation, like returning the connection to the pool + fn start_rollback(&mut self) { + if self.open { + if let Some(mut conn) = self.conn.try_lock() { + match &mut *conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + ::TransactionManager::start_rollback(c); + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + ::TransactionManager::start_rollback(c); + }, + //Should we do something for mocked connections? + #[cfg(feature = "mock")] + InnerConnection::Mock(_) => {}, + } + } + else { + //this should never happen + panic!("Dropping a locked Transaction"); + } } } } +impl Drop for DatabaseTransaction { + fn drop(&mut self) { + self.start_rollback(); + } +} + #[async_trait::async_trait] -impl<'a> ConnectionTrait for DatabaseTransaction<'a> { +impl<'a> ConnectionTrait<'a> for DatabaseTransaction { + type Stream = TransactionStream<'a>; + fn get_database_backend(&self) -> DbBackend { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(_) => DbBackend::MySql, - #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(_) => DbBackend::Postgres, - #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(_) => DbBackend::Sqlite, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), - } + // this way we don't need to lock + self.backend } async fn execute(&self, stmt: Statement) -> Result { debug_print!("{}", stmt); - let _res = match self { + let _res = match &mut *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { + InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.execute(&mut *conn).await + query.execute(conn).await .map(Into::into) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { + InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.execute(&mut *conn).await + query.execute(conn).await .map(Into::into) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { + InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.execute(&mut *conn).await + query.execute(conn).await .map(Into::into) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.execute(stmt), }; #[cfg(feature = "sqlx-dep")] _res.map_err(sqlx_error_to_exec_err) @@ -156,30 +204,27 @@ impl<'a> ConnectionTrait for DatabaseTransaction<'a> { async fn query_one(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); - let _res = match self { + let _res = match &mut *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { + InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_one(&mut *conn).await + query.fetch_one(conn).await .map(|row| Some(row.into())) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { + InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_one(&mut *conn).await + query.fetch_one(conn).await .map(|row| Some(row.into())) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { - let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_one(&mut *conn).await + InnerConnection::Sqlite(conn) => { + let query= crate::driver::sqlx_sqlite::sqlx_query(&stmt); + query.fetch_one(conn).await .map(|row| Some(row.into())) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_one(stmt), }; #[cfg(feature = "sqlx-dep")] if let Err(sqlx::Error::RowNotFound) = _res { @@ -193,65 +238,52 @@ impl<'a> ConnectionTrait for DatabaseTransaction<'a> { async fn query_all(&self, stmt: Statement) -> Result, DbErr> { debug_print!("{}", stmt); - let _res = match self { + let _res = match &mut *self.conn.lock().await { #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { + InnerConnection::MySql(conn) => { let query = crate::driver::sqlx_mysql::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_all(&mut *conn).await + query.fetch_all(conn).await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) }, #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { + InnerConnection::Postgres(conn) => { let query = crate::driver::sqlx_postgres::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_all(&mut *conn).await + query.fetch_all(conn).await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) }, #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { + InnerConnection::Sqlite(conn) => { let query = crate::driver::sqlx_sqlite::sqlx_query(&stmt); - let mut conn = conn.lock().await; - query.fetch_all(&mut *conn).await + query.fetch_all(conn).await .map(|rows| rows.into_iter().map(|r| r.into()).collect()) }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), + #[cfg(feature = "mock")] + InnerConnection::Mock(conn) => return conn.query_all(stmt), }; #[cfg(feature = "sqlx-dep")] _res.map_err(sqlx_error_to_query_err) } + fn stream(&'a self, stmt: Statement) -> Pin> + 'a>> { + Box::pin(async move { + Ok(crate::TransactionStream::build(self.conn.lock().await, stmt).await) + }) + } + + async fn begin(&self) -> Result { + DatabaseTransaction::build(Arc::clone(&self.conn), self.backend).await + } + /// Execute the function inside a transaction. /// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed. async fn transaction(&self, _callback: F) -> Result> where - F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin> + Send + 'c>> + Send + Sync, + F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin> + Send + 'c>> + Send, T: Send, E: std::error::Error + Send, { - match self { - #[cfg(feature = "sqlx-mysql")] - DatabaseTransaction::SqlxMySqlTransaction(conn) => { - let mut conn = conn.lock().await; - let transaction = DatabaseTransaction::from(conn.begin().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string())))?); - transaction.run(_callback).await - }, - #[cfg(feature = "sqlx-postgres")] - DatabaseTransaction::SqlxPostgresTransaction(conn) => { - let mut conn = conn.lock().await; - let transaction = DatabaseTransaction::from(conn.begin().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string())))?); - transaction.run(_callback).await - }, - #[cfg(feature = "sqlx-sqlite")] - DatabaseTransaction::SqlxSqliteTransaction(conn) => { - let mut conn = conn.lock().await; - let transaction = DatabaseTransaction::from(conn.begin().await.map_err(|e| TransactionError::Connection(DbErr::Query(e.to_string())))?); - transaction.run(_callback).await - }, - #[cfg(not(any(feature = "sqlx-mysql", feature = "sqlx-postgres", feature = "sqlx-sqlite")))] - _ => unimplemented!(), - } + let transaction = self.begin().await.map_err(|e| TransactionError::Connection(e))?; + transaction.run(_callback).await } } diff --git a/src/database/mock.rs b/src/database/mock.rs index 7077a6330..d0862e807 100644 --- a/src/database/mock.rs +++ b/src/database/mock.rs @@ -4,7 +4,7 @@ use crate::{ Statement, Transaction, }; use sea_query::{Value, ValueType}; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; #[derive(Debug)] pub struct MockDatabase { @@ -53,7 +53,7 @@ impl MockDatabase { } pub fn into_connection(self) -> DatabaseConnection { - DatabaseConnection::MockDatabaseConnection(MockDatabaseConnection::new(self)) + DatabaseConnection::MockDatabaseConnection(Arc::new(MockDatabaseConnection::new(self))) } pub fn append_exec_results(mut self, mut vec: Vec) -> Self { diff --git a/src/database/mod.rs b/src/database/mod.rs index 296b4d0bb..c8db4b99c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -5,6 +5,7 @@ mod statement; mod transaction; mod db_connection; mod db_transaction; +mod stream; pub use connection::*; #[cfg(feature = "mock")] @@ -13,6 +14,7 @@ pub use statement::*; pub use transaction::*; pub use db_connection::*; pub use db_transaction::*; +pub use stream::*; use crate::DbErr; diff --git a/src/database/stream/mod.rs b/src/database/stream/mod.rs new file mode 100644 index 000000000..774cf45fa --- /dev/null +++ b/src/database/stream/mod.rs @@ -0,0 +1,5 @@ +mod query; +mod transaction; + +pub use query::*; +pub use transaction::*; diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs new file mode 100644 index 000000000..553d9f7b0 --- /dev/null +++ b/src/database/stream/query.rs @@ -0,0 +1,108 @@ +use std::{pin::Pin, task::Poll, sync::Arc}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::{pool::PoolConnection, Executor}; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct QueryStream { + stmt: Statement, + conn: InnerConnection, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +#[cfg(feature = "sqlx-mysql")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::MySql(conn)) + } +} + +#[cfg(feature = "sqlx-postgres")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Postgres(conn)) + } +} + +#[cfg(feature = "sqlx-sqlite")] +impl From<(PoolConnection, Statement)> for QueryStream { + fn from((conn, stmt): (PoolConnection, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Sqlite(conn)) + } +} + +#[cfg(feature = "mock")] +impl From<(Arc, Statement)> for QueryStream { + fn from((conn, stmt): (Arc, Statement)) -> Self { + QueryStream::build(stmt, InnerConnection::Mock(conn)) + } +} + +impl std::fmt::Debug for QueryStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "QueryStream") + } +} + +impl QueryStream { + fn build(stmt: Statement, conn: InnerConnection) -> QueryStream { + QueryStreamBuilder { + stmt, + conn, + stream_builder: |conn, stmt| { + match conn { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }, + }.build() + } +} + +impl Stream for QueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs new file mode 100644 index 000000000..d945f4095 --- /dev/null +++ b/src/database/stream/transaction.rs @@ -0,0 +1,82 @@ +use std::{ops::DerefMut, pin::Pin, task::Poll}; + +use futures::Stream; +#[cfg(feature = "sqlx-dep")] +use futures::TryStreamExt; + +#[cfg(feature = "sqlx-dep")] +use sqlx::Executor; + +use futures::lock::MutexGuard; + +use crate::{DbErr, InnerConnection, QueryResult, Statement}; + +#[ouroboros::self_referencing] +pub struct TransactionStream<'a> { + stmt: Statement, + conn: MutexGuard<'a, InnerConnection>, + #[borrows(mut conn, stmt)] + #[not_covariant] + stream: Pin> + 'this>>, +} + +impl<'a> std::fmt::Debug for TransactionStream<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TransactionStream") + } +} + +impl<'a> TransactionStream<'a> { + pub(crate) async fn build(conn: MutexGuard<'a, InnerConnection>, stmt: Statement) -> TransactionStream<'a> { + TransactionStreamAsyncBuilder { + stmt, + conn, + stream_builder: |conn, stmt| Box::pin(async move { + match conn.deref_mut() { + #[cfg(feature = "sqlx-mysql")] + InnerConnection::MySql(c) => { + let query = crate::driver::sqlx_mysql::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-postgres")] + InnerConnection::Postgres(c) => { + let query = crate::driver::sqlx_postgres::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "sqlx-sqlite")] + InnerConnection::Sqlite(c) => { + let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); + Box::pin( + c.fetch(query) + .map_ok(Into::into) + .map_err(crate::sqlx_error_to_query_err) + ) as Pin>>> + }, + #[cfg(feature = "mock")] + InnerConnection::Mock(c) => { + c.fetch(stmt) + }, + } + }), + }.build().await + } +} + +impl<'a> Stream for TransactionStream<'a> { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + this.with_stream_mut(|stream| { + stream.as_mut().poll_next(cx) + }) + } +} diff --git a/src/driver/mock.rs b/src/driver/mock.rs index 823ddb32d..5e4cc84ef 100644 --- a/src/driver/mock.rs +++ b/src/driver/mock.rs @@ -2,11 +2,11 @@ use crate::{ debug_print, error::*, DatabaseConnection, DbBackend, ExecResult, MockDatabase, QueryResult, Statement, Transaction, }; -use std::fmt::Debug; -use std::sync::{ +use std::{fmt::Debug, pin::Pin, sync::{Arc, atomic::{AtomicUsize, Ordering}, Mutex, -}; +}}; +use futures::Stream; #[derive(Debug)] pub struct MockDatabaseConnector; @@ -50,7 +50,7 @@ impl MockDatabaseConnector { macro_rules! connect_mock_db { ( $syntax: expr ) => { Ok(DatabaseConnection::MockDatabaseConnection( - MockDatabaseConnection::new(MockDatabase::new($syntax)), + Arc::new(MockDatabaseConnection::new(MockDatabase::new($syntax))), )) }; } @@ -86,25 +86,32 @@ impl MockDatabaseConnection { &self.mocker } - pub async fn execute(&self, statement: Statement) -> Result { + pub fn execute(&self, statement: Statement) -> Result { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().execute(counter, statement) } - pub async fn query_one(&self, statement: Statement) -> Result, DbErr> { + pub fn query_one(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); let result = self.mocker.lock().unwrap().query(counter, statement)?; Ok(result.into_iter().next()) } - pub async fn query_all(&self, statement: Statement) -> Result, DbErr> { + pub fn query_all(&self, statement: Statement) -> Result, DbErr> { debug_print!("{}", statement); let counter = self.counter.fetch_add(1, Ordering::SeqCst); self.mocker.lock().unwrap().query(counter, statement) } + pub fn fetch(&self, statement: &Statement) -> Pin>>> { + match self.query_all(statement.clone()) { + Ok(v) => Box::pin(futures::stream::iter(v.into_iter().map(|r| Ok(r)))), + Err(e) => Box::pin(futures::stream::iter(Some(Err(e)).into_iter())), + } + } + pub fn get_database_backend(&self) -> DbBackend { self.mocker.lock().unwrap().get_database_backend() } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 2235b98b3..75e6e5ff1 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -1,11 +1,11 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin}; -use sqlx::{Connection, MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; +use sqlx::{MySql, MySqlPool, mysql::{MySqlArguments, MySqlQueryResult, MySqlRow}}; sea_query::sea_query_driver_mysql!(); use sea_query_driver_mysql::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -91,18 +91,36 @@ impl SqlxMySqlPoolConnection { } } + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_mysql(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, T: Send, E: std::error::Error + Send, { - if let Ok(conn) = &mut self.pool.acquire().await { - let transaction = DatabaseTransaction::from( - conn.begin().await.map_err(|e| { - TransactionError::Connection(DbErr::Query(e.to_string())) - })? - ); + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_mysql(conn).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index 72cb871a6..c9949375b 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -1,11 +1,11 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin}; -use sqlx::{Connection, PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; +use sqlx::{PgPool, Postgres, postgres::{PgArguments, PgQueryResult, PgRow}}; sea_query::sea_query_driver_postgres!(); use sea_query_driver_postgres::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -91,18 +91,36 @@ impl SqlxPostgresPoolConnection { } } + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_postgres(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, T: Send, E: std::error::Error + Send, { - if let Ok(conn) = &mut self.pool.acquire().await { - let transaction = DatabaseTransaction::from( - conn.begin().await.map_err(|e| { - TransactionError::Connection(DbErr::Query(e.to_string())) - })? - ); + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_postgres(conn).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index 66cf2d5df..bf06a2659 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -1,11 +1,11 @@ -use std::{pin::Pin, future::Future}; +use std::{future::Future, pin::Pin}; -use sqlx::{Connection, Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; +use sqlx::{Sqlite, SqlitePool, sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow}}; sea_query::sea_query_driver_sqlite!(); use sea_query_driver_sqlite::bind_query; -use crate::{DatabaseConnection, DatabaseTransaction, Statement, TransactionError, debug_print, error::*, executor::*}; +use crate::{DatabaseConnection, DatabaseTransaction, QueryStream, Statement, TransactionError, debug_print, error::*, executor::*}; use super::sqlx_common::*; @@ -91,18 +91,36 @@ impl SqlxSqlitePoolConnection { } } + pub async fn stream(&self, stmt: Statement) -> Result { + debug_print!("{}", stmt); + + if let Ok(conn) = self.pool.acquire().await { + Ok(QueryStream::from((conn, stmt))) + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + + pub async fn begin(&self) -> Result { + if let Ok(conn) = self.pool.acquire().await { + DatabaseTransaction::new_sqlite(conn).await + } else { + Err(DbErr::Query( + "Failed to acquire connection from pool.".to_owned(), + )) + } + } + pub async fn transaction(&self, callback: F) -> Result> where - F: for<'b> FnOnce(&'b DatabaseTransaction<'_>) -> Pin> + Send + 'b>> + Send + Sync, + F: for<'b> FnOnce(&'b DatabaseTransaction) -> Pin> + Send + 'b>> + Send, T: Send, E: std::error::Error + Send, { - if let Ok(conn) = &mut self.pool.acquire().await { - let transaction = DatabaseTransaction::from( - conn.begin().await.map_err(|e| { - TransactionError::Connection(DbErr::Query(e.to_string())) - })? - ); + if let Ok(conn) = self.pool.acquire().await { + let transaction = DatabaseTransaction::new_sqlite(conn).await.map_err(|e| TransactionError::Connection(e))?; transaction.run(callback).await } else { Err(TransactionError::Connection(DbErr::Query( diff --git a/src/entity/active_model.rs b/src/entity/active_model.rs index a7cb00553..32e9d77df 100644 --- a/src/entity/active_model.rs +++ b/src/entity/active_model.rs @@ -67,10 +67,11 @@ pub trait ActiveModelTrait: Clone + Debug { fn default() -> Self; - async fn insert(self, db: &C) -> Result + async fn insert<'a, C>(self, db: &'a C) -> Result where ::Model: IntoActiveModel, - C: ConnectionTrait, + C: ConnectionTrait<'a>, + Self: 'a, { let am = self; let exec = ::insert(am).exec(db); @@ -91,19 +92,22 @@ pub trait ActiveModelTrait: Clone + Debug { } } - async fn update(self, db: &C) -> Result - where C: ConnectionTrait { + async fn update<'a, C>(self, db: &'a C) -> Result + where + C: ConnectionTrait<'a>, + Self: 'a, + { let exec = Self::Entity::update(self).exec(db); exec.await } /// Insert the model if primary key is unset, update otherwise. /// Only works if the entity has auto increment primary key. - async fn save(self, db: &C) -> Result + async fn save<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, + Self: ActiveModelBehavior + 'a, ::Model: IntoActiveModel, - C: ConnectionTrait, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_save(am); @@ -125,10 +129,10 @@ pub trait ActiveModelTrait: Clone + Debug { } /// Delete an active model by its primary key - async fn delete(self, db: &C) -> Result + async fn delete<'a, C>(self, db: &'a C) -> Result where - Self: ActiveModelBehavior, - C: ConnectionTrait, + Self: ActiveModelBehavior + 'a, + C: ConnectionTrait<'a>, { let mut am = self; am = ActiveModelBehavior::before_delete(am); diff --git a/src/executor/delete.rs b/src/executor/delete.rs index 34c848e23..85b37cb09 100644 --- a/src/executor/delete.rs +++ b/src/executor/delete.rs @@ -20,7 +20,7 @@ where self, db: &'a C, ) -> impl Future> + 'a - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -34,7 +34,7 @@ where self, db: &'a C, ) -> impl Future> + 'a - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_delete_only(self.query, db) } @@ -45,27 +45,27 @@ impl Deleter { Self { query } } - pub fn exec( + pub fn exec<'a, C>( self, - db: &C, + db: &'a C, ) -> impl Future> + '_ - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); exec_delete(builder.build(&self.query), db) } } -async fn exec_delete_only( +async fn exec_delete_only<'a, C>( query: DeleteStatement, - db: &C, + db: &'a C, ) -> Result -where C: ConnectionTrait { +where C: ConnectionTrait<'a> { Deleter::new(query).exec(db).await } // Only Statement impl Send -async fn exec_delete(statement: Statement, db: &C) -> Result -where C: ConnectionTrait { +async fn exec_delete<'a, C>(statement: Statement, db: &C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(DeleteResult { rows_affected: result.rows_affected(), diff --git a/src/executor/insert.rs b/src/executor/insert.rs index 079560e60..3f5ced1d4 100644 --- a/src/executor/insert.rs +++ b/src/executor/insert.rs @@ -1,6 +1,6 @@ -use crate::{ActiveModelTrait, DbBackend, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; +use crate::{ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, PrimaryKeyTrait, Statement, TryFromU64, error::*}; use sea_query::InsertStatement; -use std::{future::Future, marker::PhantomData}; +use std::marker::PhantomData; #[derive(Clone, Debug)] pub struct Inserter @@ -24,19 +24,19 @@ where A: ActiveModelTrait, { #[allow(unused_mut)] - pub fn exec<'a, C>( + pub async fn exec<'a, C>( self, db: &'a C, - ) -> impl Future, DbErr>> + 'a + ) -> Result, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, A: 'a, { // TODO: extract primary key's value from query // so that self is dropped before entering await let mut query = self.query; #[cfg(feature = "sqlx-postgres")] - if db.get_database_backend() == DbBackend::Postgres && !db.is_mock_connection() { + if db.get_database_backend() == crate::DbBackend::Postgres && !db.is_mock_connection() { use crate::{sea_query::Query, Iterable}; if ::PrimaryKey::iter().count() > 0 { query.returning( @@ -46,7 +46,7 @@ where ); } } - Inserter::::new(query).exec(db) + Inserter::::new(query).exec(db).await // TODO: return primary key if extracted before, otherwise use InsertResult } } @@ -62,33 +62,33 @@ where } } - pub fn exec<'a, C>( + pub async fn exec<'a, C>( self, db: &'a C, - ) -> impl Future, DbErr>> + 'a + ) -> Result, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, A: 'a, { let builder = db.get_database_backend(); - exec_insert(builder.build(&self.query), db) + exec_insert(builder.build(&self.query), db).await } } // Only Statement impl Send -async fn exec_insert( +async fn exec_insert<'a, A, C>( statement: Statement, db: &C, ) -> Result, DbErr> where - C: ConnectionTrait, + C: ConnectionTrait<'a>, A: ActiveModelTrait, { type PrimaryKey = <::Entity as EntityTrait>::PrimaryKey; type ValueTypeOf = as PrimaryKeyTrait>::ValueType; let last_insert_id = match db.get_database_backend() { #[cfg(feature = "sqlx-postgres")] - DbBackend::Postgres if !db.is_mock_connection() => { + crate::DbBackend::Postgres if !db.is_mock_connection() => { use crate::{sea_query::Iden, Iterable}; let cols = PrimaryKey::::iter() .map(|col| col.to_string()) diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 6b5943f45..b1db9342c 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,4 +1,4 @@ -use crate::{DbBackend, ConnectionTrait, SelectorTrait, error::*}; +use crate::{ConnectionTrait, SelectorTrait, error::*}; use async_stream::stream; use futures::Stream; use sea_query::{Alias, Expr, SelectStatement}; @@ -9,7 +9,7 @@ pub type PinBoxStream<'db, Item> = Pin + 'db>>; #[derive(Clone, Debug)] pub struct Paginator<'db, C, S> where - C: ConnectionTrait, + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { pub(crate) query: SelectStatement, @@ -23,7 +23,7 @@ where impl<'db, C, S> Paginator<'db, C, S> where - C: ConnectionTrait, + C: ConnectionTrait<'db>, S: SelectorTrait + 'db, { /// Fetch a specific page @@ -67,7 +67,7 @@ where }; let num_items = match builder { #[cfg(feature = "sqlx-postgres")] - DbBackend::Postgres if !self.db.is_mock_connection() => result.try_get::("", "num_items")? as usize, + crate::DbBackend::Postgres if !self.db.is_mock_connection() => result.try_get::("", "num_items")? as usize, _ => result.try_get::("", "num_items")? as usize, }; Ok(num_items) diff --git a/src/executor/select.rs b/src/executor/select.rs index 31521b5d9..acca03691 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -1,4 +1,8 @@ +#[cfg(feature = "sqlx-dep")] +use std::pin::Pin; use crate::{ConnectionTrait, EntityTrait, FromQueryResult, IdenStatic, Iterable, JsonValue, ModelTrait, Paginator, PrimaryKeyToColumn, QueryResult, Select, SelectA, SelectB, SelectTwo, SelectTwoMany, Statement, error::*}; +#[cfg(feature = "sqlx-dep")] +use futures::{Stream, TryStreamExt}; use sea_query::SelectStatement; use std::marker::PhantomData; @@ -99,27 +103,35 @@ where } } - pub async fn one(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { self.into_model().all(db).await } - pub fn paginate( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( self, - db: &C, + db: &'a C, page_size: usize, - ) -> Paginator<'_, C, SelectModel> - where C: ConnectionTrait { + ) -> Paginator<'a, C, SelectModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &C) -> Result - where C: ConnectionTrait { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -148,33 +160,41 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + pub async fn all<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { self.into_model().all(db).await } - pub fn paginate( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub fn paginate<'a, C>( self, - db: &C, + db: &'a C, page_size: usize, - ) -> Paginator<'_, C, SelectTwoModel> - where C: ConnectionTrait { + ) -> Paginator<'a, C, SelectTwoModel> + where C: ConnectionTrait<'a> { self.into_model().paginate(db, page_size) } - pub async fn count(self, db: &C) -> Result - where C: ConnectionTrait { + pub async fn count<'a, C>(self, db: &'a C) -> Result + where C: ConnectionTrait<'a> { self.paginate(db, 1).num_items().await } } @@ -203,19 +223,27 @@ where } } - pub async fn one( + pub async fn one<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { self.into_model().one(db).await } - pub async fn all( + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result), DbErr>> + 'b, DbErr> + where + C: ConnectionTrait<'a>, + { + self.into_model().stream(db).await + } + + pub async fn all<'a, C>( self, db: &C, ) -> Result)>, DbErr> - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { let rows = self.into_model().all(db).await?; Ok(consolidate_query_result::(rows)) } @@ -234,8 +262,8 @@ impl Selector where S: SelectorTrait, { - pub async fn one(mut self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn one<'a, C>(mut self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); self.query.limit(1); let row = db.query_one(builder.build(&self.query)).await?; @@ -245,8 +273,8 @@ where } } - pub async fn all(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); let rows = db.query_all(builder.build(&self.query)).await?; let mut models = Vec::new(); @@ -256,8 +284,21 @@ where Ok(models) } - pub fn paginate(self, db: &C, page_size: usize) -> Paginator<'_, C, S> - where C: ConnectionTrait { + #[cfg(feature = "sqlx-dep")] + pub async fn stream<'a: 'b, 'b, C>(self, db: &'a C) -> Result> + 'b>>, DbErr> + where + C: ConnectionTrait<'a>, + S: 'b, + { + let builder = db.get_database_backend(); + let stream = db.stream(builder.build(&self.query)).await?; + Ok(Box::pin(stream.and_then(|row| { + futures::future::ready(S::from_raw_query_result(row)) + }))) + } + + pub fn paginate<'a, C>(self, db: &'a C, page_size: usize) -> Paginator<'a, C, S> + where C: ConnectionTrait<'a> { Paginator { query: self.query, page: 0, @@ -451,8 +492,8 @@ where /// ),] /// ); /// ``` - pub async fn one(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn one<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let row = db.query_one(self.stmt).await?; match row { Some(row) => Ok(Some(S::from_raw_query_result(row)?)), @@ -491,8 +532,8 @@ where /// ),] /// ); /// ``` - pub async fn all(self, db: &C) -> Result, DbErr> - where C: ConnectionTrait { + pub async fn all<'a, C>(self, db: &C) -> Result, DbErr> + where C: ConnectionTrait<'a> { let rows = db.query_all(self.stmt).await?; let mut models = Vec::new(); for row in rows.into_iter() { diff --git a/src/executor/update.rs b/src/executor/update.rs index da0f5c401..06cd514eb 100644 --- a/src/executor/update.rs +++ b/src/executor/update.rs @@ -16,10 +16,10 @@ impl<'a, A: 'a> UpdateOne where A: ActiveModelTrait, { - pub fn exec(self, db: &'a C) -> impl Future> + 'a - where C: ConnectionTrait { + pub async fn exec<'b, C>(self, db: &'b C) -> Result + where C: ConnectionTrait<'b> { // so that self is dropped before entering await - exec_update_and_return_original(self.query, self.model, db) + exec_update_and_return_original(self.query, self.model, db).await } } @@ -31,7 +31,7 @@ where self, db: &'a C, ) -> impl Future> + 'a - where C: ConnectionTrait { + where C: ConnectionTrait<'a> { // so that self is dropped before entering await exec_update_only(self.query, db) } @@ -42,40 +42,40 @@ impl Updater { Self { query } } - pub fn exec( + pub async fn exec<'a, C>( self, - db: &C, - ) -> impl Future> + '_ - where C: ConnectionTrait { + db: &'a C, + ) -> Result + where C: ConnectionTrait<'a> { let builder = db.get_database_backend(); - exec_update(builder.build(&self.query), db) + exec_update(builder.build(&self.query), db).await } } -async fn exec_update_only( +async fn exec_update_only<'a, C>( query: UpdateStatement, - db: &C, + db: &'a C, ) -> Result -where C: ConnectionTrait { +where C: ConnectionTrait<'a> { Updater::new(query).exec(db).await } -async fn exec_update_and_return_original( +async fn exec_update_and_return_original<'a, A, C>( query: UpdateStatement, model: A, - db: &C, + db: &'a C, ) -> Result where A: ActiveModelTrait, - C: ConnectionTrait, + C: ConnectionTrait<'a>, { Updater::new(query).exec(db).await?; Ok(model) } // Only Statement impl Send -async fn exec_update(statement: Statement, db: &C) -> Result -where C: ConnectionTrait { +async fn exec_update<'a, C>(statement: Statement, db: &'a C) -> Result +where C: ConnectionTrait<'a> { let result = db.execute(statement).await?; Ok(UpdateResult { rows_affected: result.rows_affected(), diff --git a/tests/stream_tests.rs b/tests/stream_tests.rs new file mode 100644 index 000000000..969b93e1d --- /dev/null +++ b/tests/stream_tests.rs @@ -0,0 +1,37 @@ +pub mod common; + +pub use common::{bakery_chain::*, setup::*, TestContext}; +pub use sea_orm::entity::*; +pub use sea_orm::{QueryFilter, ConnectionTrait, DbErr}; +use futures::StreamExt; + +#[sea_orm_macros::test] +#[cfg(any( + feature = "sqlx-mysql", + feature = "sqlx-sqlite", + feature = "sqlx-postgres" +))] +pub async fn stream() -> Result<(), DbErr> { + let ctx = TestContext::new("stream").await; + + let bakery = bakery::ActiveModel { + name: Set("SeaSide Bakery".to_owned()), + profit_margin: Set(10.4), + ..Default::default() + } + .save(&ctx.db) + .await?; + + let result = Bakery::find_by_id(bakery.id.clone().unwrap()) + .stream(&ctx.db) + .await? + .next() + .await + .unwrap()?; + + assert_eq!(result.id, bakery.id.unwrap()); + + ctx.delete().await; + + Ok(()) +} diff --git a/tests/transaction_tests.rs b/tests/transaction_tests.rs index 33de12a5c..539eaefcc 100644 --- a/tests/transaction_tests.rs +++ b/tests/transaction_tests.rs @@ -14,7 +14,7 @@ pub use sea_orm::{QueryFilter, ConnectionTrait}; pub async fn transaction() { let ctx = TestContext::new("transaction_test").await; - ctx.db.transaction::<_, (), DbErr>(|txn| Box::pin(async move { + ctx.db.transaction::<_, _, DbErr>(|txn| Box::pin(async move { let _ = bakery::ActiveModel { name: Set("SeaSide Bakery".to_owned()), profit_margin: Set(10.4),