From 5810507811a65cb7add481592fcb86e4223c5f58 Mon Sep 17 00:00:00 2001 From: Jerome Gravel-Niquet Date: Mon, 23 Oct 2023 17:54:46 -0400 Subject: [PATCH] Improve SQLite-typed input for API (#82) * sqlite-input-improvements * fix unused imports --- Cargo.lock | 1 + Cargo.toml | 2 +- crates/corro-agent/src/agent.rs | 2 +- crates/corro-agent/src/api/public/mod.rs | 50 +++++--- crates/corro-agent/src/api/public/pubsub.rs | 25 +++- crates/corro-api-types/Cargo.toml | 3 +- crates/corro-api-types/src/lib.rs | 127 +++++++++++++++++++- crates/corro-tpl/src/lib.rs | 28 ++--- crates/corrosion/src/main.rs | 6 +- 9 files changed, 199 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 57fe977d..a230a58a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -765,6 +765,7 @@ dependencies = [ "hex", "rusqlite", "serde", + "serde_json", "smallvec", "strum", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 108166af..75c430c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] } rustls-pemfile = "1.0.2" seahash = "4.1.0" serde = "1.0.159" -serde_json = "1.0.95" +serde_json = { version = "1.0.95", features = ["raw_value"] } serde_with = "2.3.2" smallvec = { version = "1.11.0", features = ["serde", "write", "union"] } speedy = { version = "0.8.7", features = ["uuid", "smallvec"], package = "corro-speedy" } diff --git a/crates/corro-agent/src/agent.rs b/crates/corro-agent/src/agent.rs index 42c51993..29b5b5b4 100644 --- a/crates/corro-agent/src/agent.rs +++ b/crates/corro-agent/src/agent.rs @@ -2736,7 +2736,7 @@ pub mod tests { use serde_json::json; use spawn::wait_for_all_pending_handles; use tokio::time::{sleep, timeout, MissedTickBehavior}; - use tracing::{info, info_span}; + use tracing::info_span; use tripwire::Tripwire; use super::*; diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index f30e21ae..16d60d6b 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -308,16 +308,25 @@ where #[tracing::instrument(skip_all, err)] fn execute_statement(tx: &Transaction, stmt: &Statement) -> rusqlite::Result { - let mut prepped = match &stmt { - Statement::Simple(q) => tx.prepare(q), - Statement::WithParams(q, _) => tx.prepare(q), - Statement::WithNamedParams(q, _) => tx.prepare(q), - }?; + let mut prepped = tx.prepare(stmt.query())?; match stmt { - Statement::Simple(_) => prepped.execute([]), - Statement::WithParams(_, params) => prepped.execute(params_from_iter(params)), - Statement::WithNamedParams(_, params) => prepped.execute( + Statement::Simple(_) + | Statement::Verbose { + params: None, + named_params: None, + .. + } => prepped.execute([]), + Statement::WithParams(_, params) + | Statement::Verbose { + params: Some(params), + .. + } => prepped.execute(params_from_iter(params)), + Statement::WithNamedParams(_, params) + | Statement::Verbose { + named_params: Some(params), + .. + } => prepped.execute( params .iter() .map(|(k, v)| (k.as_str(), v as &dyn ToSql)) @@ -429,11 +438,7 @@ async fn build_query_rows_response( } }; - let prepped_res = block_in_place(|| match &stmt { - Statement::Simple(q) => conn.prepare(q), - Statement::WithParams(q, _) => conn.prepare(q), - Statement::WithNamedParams(q, _) => conn.prepare(q), - }); + let prepped_res = block_in_place(|| conn.prepare(stmt.query())); let mut prepped = match prepped_res { Ok(prepped) => prepped, @@ -476,9 +481,22 @@ async fn build_query_rows_response( let start = Instant::now(); let query = match stmt { - Statement::Simple(_) => prepped.query(()), - Statement::WithParams(_, params) => prepped.query(params_from_iter(params)), - Statement::WithNamedParams(_, params) => prepped.query( + Statement::Simple(_) + | Statement::Verbose { + params: None, + named_params: None, + .. + } => prepped.query(()), + Statement::WithParams(_, params) + | Statement::Verbose { + params: Some(params), + .. + } => prepped.query(params_from_iter(params)), + Statement::WithNamedParams(_, params) + | Statement::Verbose { + named_params: Some(params), + .. + } => prepped.query( params .iter() .map(|(k, v)| (k.as_str(), v as &dyn ToSql)) diff --git a/crates/corro-agent/src/api/public/pubsub.rs b/crates/corro-agent/src/api/public/pubsub.rs index f9f6fef2..bcdfe4d8 100644 --- a/crates/corro-agent/src/api/public/pubsub.rs +++ b/crates/corro-agent/src/api/public/pubsub.rs @@ -202,16 +202,31 @@ pub async fn process_sub_channel( fn expanded_statement(conn: &Connection, stmt: &Statement) -> rusqlite::Result> { Ok(match stmt { - Statement::Simple(q) => conn.prepare(q)?.expanded_sql(), - Statement::WithParams(q, params) => { - let mut prepped = conn.prepare(q)?; + Statement::Simple(query) + | Statement::Verbose { + query, + params: None, + named_params: None, + } => conn.prepare(query)?.expanded_sql(), + Statement::WithParams(query, params) + | Statement::Verbose { + query, + params: Some(params), + .. + } => { + let mut prepped = conn.prepare(query)?; for (i, param) in params.iter().enumerate() { prepped.raw_bind_parameter(i + 1, param)?; } prepped.expanded_sql() } - Statement::WithNamedParams(q, params) => { - let mut prepped = conn.prepare(q)?; + Statement::WithNamedParams(query, params) + | Statement::Verbose { + query, + named_params: Some(params), + .. + } => { + let mut prepped = conn.prepare(query)?; for (k, v) in params.iter() { let idx = match prepped.parameter_index(k)? { Some(idx) => idx, diff --git a/crates/corro-api-types/Cargo.toml b/crates/corro-api-types/Cargo.toml index 02658de6..9bf05601 100644 --- a/crates/corro-api-types/Cargo.toml +++ b/crates/corro-api-types/Cargo.toml @@ -13,8 +13,9 @@ compact_str = { workspace = true } hex = { workspace = true } rusqlite = { workspace = true } serde = { workspace = true } +serde_json = { workspace = true } smallvec = { workspace = true } speedy = { workspace = true } strum = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } \ No newline at end of file +tokio = { workspace = true } diff --git a/crates/corro-api-types/src/lib.rs b/crates/corro-api-types/src/lib.rs index 3ccf6b2b..d83942b6 100644 --- a/crates/corro-api-types/src/lib.rs +++ b/crates/corro-api-types/src/lib.rs @@ -11,6 +11,7 @@ use rusqlite::{ Row, ToSql, }; use serde::{Deserialize, Serialize}; +use serde_json::value::RawValue; use smallvec::{SmallVec, ToSmallVec}; use speedy::{Context, Readable, Reader, Writable, Writer}; use sqlite::ChangeType; @@ -120,9 +121,25 @@ impl ToSql for ChangeId { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] pub enum Statement { + Verbose { + query: String, + params: Option>, + named_params: Option>, + }, Simple(String), - WithParams(String, Vec), - WithNamedParams(String, HashMap), + WithParams(String, Vec), + WithNamedParams(String, HashMap), +} + +impl Statement { + pub fn query(&self) -> &str { + match self { + Statement::Verbose { query, .. } + | Statement::Simple(query) + | Statement::WithParams(query, _) + | Statement::WithNamedParams(query, _) => query, + } + } } impl From<&str> for Statement { @@ -292,6 +309,76 @@ impl FromSql for ColumnType { } } +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum SqliteParam { + #[default] + Null, + Bool(bool), + Integer(i64), + Real(f64), + Text(CompactString), + Blob(SmallVec<[u8; 512]>), + Json(Box), +} + +impl From<&str> for SqliteParam { + fn from(value: &str) -> Self { + Self::Text(value.into()) + } +} + +impl From> for SqliteParam { + fn from(value: Vec) -> Self { + Self::Blob(value.into()) + } +} + +impl From for SqliteParam { + fn from(value: String) -> Self { + Self::Text(value.into()) + } +} + +impl From for SqliteParam { + fn from(value: u16) -> Self { + Self::Integer(value as i64) + } +} + +impl From for SqliteParam { + fn from(value: i64) -> Self { + Self::Integer(value) + } +} + +impl ToSql for SqliteParam { + fn to_sql(&self) -> rusqlite::Result> { + Ok(match self { + SqliteParam::Null => ToSqlOutput::Owned(Value::Null), + SqliteParam::Bool(v) => ToSqlOutput::Owned(Value::Integer(*v as i64)), + SqliteParam::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)), + SqliteParam::Real(f) => ToSqlOutput::Owned(Value::Real(*f)), + SqliteParam::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())), + SqliteParam::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)), + SqliteParam::Json(map) => ToSqlOutput::Borrowed(ValueRef::Text(map.get().as_bytes())), + }) + } +} + +impl<'a> ToSql for SqliteValueRef<'a> { + fn to_sql(&self) -> rusqlite::Result> { + Ok(match self { + SqliteValueRef::Null => ToSqlOutput::Owned(Value::Null), + SqliteValueRef::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)), + SqliteValueRef::Real(f) => ToSqlOutput::Owned(Value::Real(*f)), + SqliteValueRef::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())), + SqliteValueRef::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)), + }) + } +} + #[allow(clippy::large_enum_variant)] #[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Hash)] #[serde(untagged)] @@ -655,3 +742,39 @@ impl ToSql for ColumnName { self.0.as_str().to_sql() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_statement_serialization() { + let s = serde_json::to_string(&vec![Statement::WithParams( + "select 1 + from table + where column = ?" + .into(), + vec!["my-value".into()], + )]) + .unwrap(); + println!("{s}"); + + let stmts: Vec = serde_json::from_str(&s).unwrap(); + println!("stmts: {stmts:?}"); + + let json = r#"[["some statement",[1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]]]"#; + + let value: serde_json::Value = serde_json::from_str(json).unwrap(); + println!("value: {value:#?}"); + + let stmts: Vec = serde_json::from_str(json).unwrap(); + println!("stmts: {stmts:?}"); + + let json = r#"[{"query": "some statement", "params": [1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]}]"#; + let value: serde_json::Value = serde_json::from_str(json).unwrap(); + println!("value: {value:#?}"); + + let stmts: Vec = serde_json::from_str(json).unwrap(); + println!("stmts: {stmts:?}"); + } +} diff --git a/crates/corro-tpl/src/lib.rs b/crates/corro-tpl/src/lib.rs index 726a1a90..1328a548 100644 --- a/crates/corro-tpl/src/lib.rs +++ b/crates/corro-tpl/src/lib.rs @@ -12,6 +12,7 @@ use compact_str::ToCompactString; use corro_client::sub::SubscriptionStream; use corro_client::CorrosionApiClient; use corro_types::api::QueryEvent; +use corro_types::api::SqliteParam; use corro_types::api::Statement; use corro_types::change::SqliteValue; use futures::StreamExt; @@ -536,33 +537,28 @@ impl Engine { } }); - fn dyn_to_sql(v: Dynamic) -> Result> { + fn dyn_to_sql(v: Dynamic) -> Result> { Ok(match v.type_name() { - "()" => SqliteValue::Null, - "i64" => SqliteValue::Integer( + "()" => SqliteParam::Null, + "i64" => SqliteParam::Integer( v.as_int() .map_err(|_e| Box::new(EvalAltResult::from("could not cast to i64")))?, ), - "f64" => SqliteValue::Real(corro_types::api::Real( + "f64" => SqliteParam::Real( v.as_float() .map_err(|_e| Box::new(EvalAltResult::from("could not cast to f64")))?, - )), - "bool" => { - if v.as_bool() - .map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))? - { - SqliteValue::Integer(1) - } else { - SqliteValue::Integer(0) - } - } - "blob" => SqliteValue::Blob( + ), + "bool" => SqliteParam::Bool( + v.as_bool() + .map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?, + ), + "blob" => SqliteParam::Blob( v.into_blob() .map_err(|_e| Box::new(EvalAltResult::from("could not cast to blob")))? .into(), ), // convert everything else into a string, including a string - _ => SqliteValue::Text(v.to_compact_string()), + _ => SqliteParam::Text(v.to_compact_string()), }) } diff --git a/crates/corrosion/src/main.rs b/crates/corrosion/src/main.rs index 685593b6..367054e6 100644 --- a/crates/corrosion/src/main.rs +++ b/crates/corrosion/src/main.rs @@ -12,7 +12,7 @@ use command::{ tls::{generate_ca, generate_client_cert, generate_server_cert}, tpl::TemplateFlags, }; -use corro_api_types::SqliteValue; +use corro_api_types::SqliteParam; use corro_client::CorrosionApiClient; use corro_types::{ api::{ExecResult, QueryEvent, Statement}, @@ -301,7 +301,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> { } else { Statement::WithParams( query.clone(), - param.iter().map(|p| SqliteValue::Text(p.into())).collect(), + param.iter().map(|p| SqliteParam::Text(p.into())).collect(), ) }; @@ -359,7 +359,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> { } else { Statement::WithParams( query.clone(), - param.iter().map(|p| SqliteValue::Text(p.into())).collect(), + param.iter().map(|p| SqliteParam::Text(p.into())).collect(), ) };