Skip to content

Commit

Permalink
Transaction 2 (#199)
Browse files Browse the repository at this point in the history
* Stream implementation

* use ouroboros to cover self-references

* Reduce test size

* Reduce test size

* Complete transaction rewrite + streams

* Less radioactive unsafe

* Mutex transaction

* Solve clippy lints

* panic on drop of a locked transaction

* panic on drop of a locked transaction

* panic on drop of a locked transaction

* Centralize Sync requirement
  • Loading branch information
nappa85 authored Oct 2, 2021
1 parent 460462b commit 9b9c217
Show file tree
Hide file tree
Showing 21 changed files with 702 additions and 274 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ serde = { version = "^1.0", features = ["derive"] }
serde_json = { version = "^1", optional = true }
sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"

[dev-dependencies]
smol = { version = "^1.2" }
Expand Down
51 changes: 43 additions & 8 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{pin::Pin, future::Future};
use std::{future::Future, pin::Pin, sync::Arc};
use crate::{DatabaseTransaction, ConnectionTrait, ExecResult, QueryResult, Statement, StatementBuilder, TransactionError, error::*};
use sea_query::{MysqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder};

Expand All @@ -11,7 +11,7 @@ pub enum DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")]
SqlxSqlitePoolConnection(crate::SqlxSqlitePoolConnection),
#[cfg(feature = "mock")]
MockDatabaseConnection(crate::MockDatabaseConnection),
MockDatabaseConnection(Arc<crate::MockDatabaseConnection>),
Disconnected,
}

Expand Down Expand Up @@ -53,7 +53,9 @@ impl std::fmt::Debug for DatabaseConnection {
}

#[async_trait::async_trait]
impl ConnectionTrait for DatabaseConnection {
impl<'a> ConnectionTrait<'a> for DatabaseConnection {
type Stream = crate::QueryStream;

fn get_database_backend(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
Expand All @@ -77,7 +79,7 @@ impl ConnectionTrait for DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt).await,
DatabaseConnection::MockDatabaseConnection(conn) => conn.execute(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
Expand All @@ -91,7 +93,7 @@ impl ConnectionTrait for DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt).await,
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_one(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}
Expand All @@ -105,16 +107,46 @@ impl ConnectionTrait for DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt).await,
DatabaseConnection::MockDatabaseConnection(conn) => conn.query_all(stmt),
DatabaseConnection::Disconnected => Err(DbErr::Conn("Disconnected".to_owned())),
}
}

fn stream(&'a self, stmt: Statement) -> Pin<Box<dyn Future<Output=Result<Self::Stream, DbErr>> + 'a>> {
Box::pin(async move {
Ok(match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => crate::QueryStream::from((Arc::clone(conn), stmt)),
DatabaseConnection::Disconnected => panic!("Disconnected"),
})
})
}

async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.begin().await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => DatabaseTransaction::new_mock(Arc::clone(conn)).await,
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, _callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send + Sync,
F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send,
T: Send,
E: std::error::Error + Send,
{
Expand All @@ -126,7 +158,10 @@ impl ConnectionTrait for DatabaseConnection {
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(conn) => conn.transaction(_callback).await,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(_) => unimplemented!(), //TODO: support transaction in mock connection
DatabaseConnection::MockDatabaseConnection(conn) => {
let transaction = DatabaseTransaction::new_mock(Arc::clone(conn)).await.map_err(|e| TransactionError::Connection(e))?;
transaction.run(_callback).await
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}
Expand Down
28 changes: 24 additions & 4 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,24 @@
use std::{pin::Pin, future::Future};
use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, QueryResult, Statement, TransactionError};
use std::{future::Future, pin::Pin, sync::Arc};
use crate::{DatabaseTransaction, DbBackend, DbErr, ExecResult, MockDatabaseConnection, QueryResult, Statement, TransactionError};
use futures::Stream;
#[cfg(feature = "sqlx-dep")]
use sqlx::pool::PoolConnection;

pub(crate) enum InnerConnection {
#[cfg(feature = "sqlx-mysql")]
MySql(PoolConnection<sqlx::MySql>),
#[cfg(feature = "sqlx-postgres")]
Postgres(PoolConnection<sqlx::Postgres>),
#[cfg(feature = "sqlx-sqlite")]
Sqlite(PoolConnection<sqlx::Sqlite>),
#[cfg(feature = "mock")]
Mock(Arc<MockDatabaseConnection>),
}

#[async_trait::async_trait]
pub trait ConnectionTrait: Sync {
pub trait ConnectionTrait<'a>: Sync {
type Stream: Stream<Item=Result<QueryResult, DbErr>>;

fn get_database_backend(&self) -> DbBackend;

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr>;
Expand All @@ -11,11 +27,15 @@ pub trait ConnectionTrait: Sync {

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr>;

fn stream(&'a self, stmt: Statement) -> Pin<Box<dyn Future<Output=Result<Self::Stream, DbErr>> + 'a>>;

async fn begin(&self) -> Result<DatabaseTransaction, DbErr>;

/// Execute the function inside a transaction.
/// If the function returns an error, the transaction will be rolled back. If it does not return an error, the transaction will be committed.
async fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
where
F: for<'c> FnOnce(&'c DatabaseTransaction<'_>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send + Sync,
F: for<'c> FnOnce(&'c DatabaseTransaction) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'c>> + Send,
T: Send,
E: std::error::Error + Send;

Expand Down
Loading

0 comments on commit 9b9c217

Please sign in to comment.