diff --git a/src/coercions.rs b/src/coercions.rs index c69d255..de17594 100644 --- a/src/coercions.rs +++ b/src/coercions.rs @@ -6,6 +6,8 @@ use serde_json::Value; use std::collections::HashMap; use std::str::FromStr; +use crate::serialization::DeserializedMessage; + #[derive(Debug, Clone, PartialEq)] #[allow(unused)] enum CoercionNode { @@ -72,7 +74,7 @@ fn build_coercion_node(data_type: &DataType) -> Option { /// Applies all data coercions specified by the [`CoercionTree`] to the [`Value`]. /// Though it does not currently, this function should approximate or improve on the coercions applied by [Spark's `from_json`](https://spark.apache.org/docs/latest/api/sql/index.html#from_json) -pub(crate) fn coerce(value: &mut Value, coercion_tree: &CoercionTree) { +pub(crate) fn coerce(value: &mut DeserializedMessage, coercion_tree: &CoercionTree) { if let Some(context) = value.as_object_mut() { for (field_name, coercion) in coercion_tree.root.iter() { if let Some(value) = context.get_mut(field_name) { @@ -322,7 +324,7 @@ mod tests { let coercion_tree = create_coercion_tree(&delta_schema); - let mut messages = vec![ + let mut messages: Vec = vec![ json!({ "level1_string": "a", "level1_integer": 0, @@ -380,7 +382,10 @@ mod tests { // This is valid epoch micros, but typed as a string on the way in. We WON'T coerce it. "level1_timestamp": "1636668718000000", }), - ]; + ] + .into_iter() + .map(|f| f.into()) + .collect(); for message in messages.iter_mut() { coerce(message, &coercion_tree); @@ -447,7 +452,7 @@ mod tests { ]; for i in 0..messages.len() { - assert_eq!(messages[i], expected[i]); + assert_eq!(messages[i].clone().message(), expected[i]); } } } diff --git a/src/dead_letters.rs b/src/dead_letters.rs index a4e709e..e7f6e1f 100644 --- a/src/dead_letters.rs +++ b/src/dead_letters.rs @@ -14,6 +14,7 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +use crate::serialization::DeserializedMessage; use crate::{transforms::TransformError, writer::*}; #[cfg(feature = "s3")] @@ -55,11 +56,11 @@ impl DeadLetter { /// Creates a dead letter from a failed transform. /// `base64_bytes` will always be `None`. - pub(crate) fn from_failed_transform(value: &Value, err: TransformError) -> Self { + pub(crate) fn from_failed_transform(value: &DeserializedMessage, err: TransformError) -> Self { let timestamp = Utc::now(); Self { base64_bytes: None, - json_string: Some(value.to_string()), + json_string: Some(value.clone().message().to_string()), error: Some(err.to_string()), timestamp: timestamp .timestamp_nanos_opt() @@ -286,9 +287,10 @@ impl DeadLetterQueue for DeltaSinkDeadLetterQueue { .map(|dl| { serde_json::to_value(dl) .map_err(|e| DeadLetterQueueError::SerdeJson { source: e }) - .and_then(|mut v| { + .and_then(|v| { self.transformer - .transform(&mut v, None as Option<&BorrowedMessage>)?; + // TODO: this can't be right, shouldn't this function takje DeserializedMessage + .transform(&mut v.clone().into(), None as Option<&BorrowedMessage>)?; Ok(v) }) }) @@ -297,7 +299,10 @@ impl DeadLetterQueue for DeltaSinkDeadLetterQueue { let version = self .delta_writer - .insert_all(&mut self.table, values) + .insert_all( + &mut self.table, + values.into_iter().map(|v| v.into()).collect(), + ) .await?; if self.write_checkpoints { diff --git a/src/lib.rs b/src/lib.rs index f3111aa..b60b68a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,7 +45,8 @@ mod dead_letters; mod delta_helpers; mod metrics; mod offsets; -mod serialization; +#[allow(missing_docs)] +pub mod serialization; mod transforms; mod value_buffers; /// Doc @@ -56,6 +57,7 @@ use crate::value_buffers::{ConsumedBuffers, ValueBuffers}; use crate::{ dead_letters::*, metrics::*, + serialization::*, transforms::*, writer::{DataWriter, DataWriterError}, }; @@ -207,8 +209,9 @@ pub enum IngestError { } /// Formats for message parsing -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub enum MessageFormat { + #[default] /// Parses messages as json and uses the inferred schema DefaultJson, @@ -733,7 +736,7 @@ struct IngestProcessor { coercion_tree: CoercionTree, table: DeltaTable, delta_writer: DataWriter, - value_buffers: ValueBuffers, + value_buffers: ValueBuffers, delta_partition_offsets: HashMap>, latency_timer: Instant, dlq: Box, @@ -864,7 +867,7 @@ impl IngestProcessor { async fn deserialize_message( &mut self, msg: &M, - ) -> Result + ) -> Result where M: Message + Send + Sync, { diff --git a/src/serialization.rs b/src/serialization.rs index 9dfc68b..55c2fff 100644 --- a/src/serialization.rs +++ b/src/serialization.rs @@ -8,12 +8,53 @@ use serde_json::Value; use crate::{dead_letters::DeadLetter, MessageDeserializationError, MessageFormat}; +use deltalake_core::arrow::datatypes::Schema as ArrowSchema; + +/// Structure which contains the [serde_json::Value] and the inferred schema of the message +/// +/// The [ArrowSchema] helps with schema evolution +#[derive(Clone, Debug, Default, PartialEq)] +pub struct DeserializedMessage { + message: Value, + schema: Option, +} + +impl DeserializedMessage { + pub fn schema(&self) -> &Option { + &self.schema + } + pub fn message(self) -> Value { + self.message + } + pub fn get(&self, key: &str) -> Option<&Value> { + self.message.get(key) + } + pub fn as_object_mut(&mut self) -> Option<&mut serde_json::Map> { + self.message.as_object_mut() + } +} + +/// Allow for `.into()` on [Value] for ease of use +impl From for DeserializedMessage { + fn from(message: Value) -> Self { + // XXX: This seems wasteful, this function should go away, and the deserializers should + // infer straight from the buffer stream + let iter = vec![message.clone()].into_iter().map(|v| Ok(v)); + let schema = + match deltalake_core::arrow::json::reader::infer_json_schema_from_iterator(iter) { + Ok(schema) => Some(schema), + _ => None, + }; + Self { message, schema } + } +} + #[async_trait] pub(crate) trait MessageDeserializer { async fn deserialize( &mut self, message_bytes: &[u8], - ) -> Result; + ) -> Result; } pub(crate) struct MessageDeserializerFactory {} @@ -80,11 +121,15 @@ impl MessageDeserializerFactory { } } +#[derive(Clone, Debug, Default)] struct DefaultDeserializer {} #[async_trait] impl MessageDeserializer for DefaultDeserializer { - async fn deserialize(&mut self, payload: &[u8]) -> Result { + async fn deserialize( + &mut self, + payload: &[u8], + ) -> Result { let value: Value = match serde_json::from_slice(payload) { Ok(v) => v, Err(e) => { @@ -94,7 +139,41 @@ impl MessageDeserializer for DefaultDeserializer { } }; - Ok(value) + Ok(value.into()) + } +} + +#[cfg(test)] +mod default_tests { + use super::*; + + #[tokio::test] + async fn deserialize_with_schema() { + let mut deser = DefaultDeserializer::default(); + let message = deser + .deserialize(r#"{"hello" : "world"}"#.as_bytes()) + .await + .expect("Failed to deserialize trivial JSON"); + assert!( + message.schema().is_some(), + "The DeserializedMessage doesn't have a schema!" + ); + } + + #[tokio::test] + async fn deserialize_simple_json() { + #[derive(serde::Deserialize)] + struct HW { + hello: String, + } + + let mut deser = DefaultDeserializer::default(); + let message = deser + .deserialize(r#"{"hello" : "world"}"#.as_bytes()) + .await + .expect("Failed to deserialize trivial JSON"); + let value: HW = serde_json::from_value(message.message).expect("Failed to coerce"); + assert_eq!("world", value.hello); } } @@ -116,11 +195,11 @@ impl MessageDeserializer for AvroDeserializer { async fn deserialize( &mut self, message_bytes: &[u8], - ) -> Result { + ) -> Result { match self.decoder.decode_with_schema(Some(message_bytes)).await { Ok(drs) => match drs { Some(v) => match Value::try_from(v.value) { - Ok(v) => Ok(v), + Ok(v) => Ok(v.into()), Err(e) => Err(MessageDeserializationError::AvroDeserialization { dead_letter: DeadLetter::from_failed_deserialization( message_bytes, @@ -147,7 +226,7 @@ impl MessageDeserializer for AvroSchemaDeserializer { async fn deserialize( &mut self, message_bytes: &[u8], - ) -> Result { + ) -> Result { let reader_result = match &self.schema { None => apache_avro::Reader::new(Cursor::new(message_bytes)), Some(schema) => apache_avro::Reader::with_schema(schema, Cursor::new(message_bytes)), @@ -162,7 +241,7 @@ impl MessageDeserializer for AvroSchemaDeserializer { }; return match v { - Ok(value) => Ok(value), + Ok(value) => Ok(value.into()), Err(e) => Err(MessageDeserializationError::AvroDeserialization { dead_letter: DeadLetter::from_failed_deserialization( message_bytes, @@ -221,11 +300,11 @@ impl MessageDeserializer for JsonDeserializer { async fn deserialize( &mut self, message_bytes: &[u8], - ) -> Result { + ) -> Result { let decoder = self.decoder.borrow_mut(); match decoder.decode(Some(message_bytes)).await { Ok(drs) => match drs { - Some(v) => Ok(v.value), + Some(v) => Ok(v.value.into()), None => return Err(MessageDeserializationError::EmptyPayload), }, Err(e) => { diff --git a/src/transforms.rs b/src/transforms.rs index b4803b6..3cb83cc 100644 --- a/src/transforms.rs +++ b/src/transforms.rs @@ -1,3 +1,4 @@ +use crate::serialization::DeserializedMessage; use chrono::prelude::*; use jmespatch::{ functions::{ArgumentType, CustomFunction, Signature}, @@ -348,13 +349,13 @@ impl Transformer { /// The optional `kafka_message` must be provided to include well known Kafka properties in the value. pub(crate) fn transform( &self, - value: &mut Value, + value: &mut DeserializedMessage, kafka_message: Option<&M>, ) -> Result<(), TransformError> where M: Message, { - let data = Variable::try_from(value.clone())?; + let data = Variable::try_from(value.clone().message())?; match value.as_object_mut() { Some(map) => { @@ -378,7 +379,7 @@ impl Transformer { Ok(()) } _ => Err(TransformError::ValueNotAnObject { - value: value.to_owned(), + value: value.clone().message(), }), } } @@ -510,7 +511,7 @@ mod tests { #[test] fn transforms_with_substr() { - let mut test_value = json!({ + let test_value = json!({ "name": "A", "modified": "2021-03-16T14:38:58Z", }); @@ -524,6 +525,7 @@ mod tests { 0, None, ); + let mut test_value: DeserializedMessage = test_value.into(); let mut transforms = HashMap::new(); @@ -540,6 +542,7 @@ mod tests { let name = test_value.get("name").unwrap().as_str().unwrap(); let modified = test_value.get("modified").unwrap().as_str().unwrap(); + println!("TEST: {test_value:?}"); let modified_date = test_value.get("modified_date").unwrap().as_str().unwrap(); assert_eq!("A", name); @@ -567,7 +570,7 @@ mod tests { fn test_transforms_with_epoch_seconds_to_iso8601() { let expected_iso = "2021-07-20T23:18:18Z"; - let mut test_value = json!({ + let test_value = json!({ "name": "A", "epoch_seconds_float": 1626823098.51995, "epoch_seconds_int": 1626823098, @@ -584,6 +587,7 @@ mod tests { 0, None, ); + let mut test_value: DeserializedMessage = test_value.into(); let mut transforms = HashMap::new(); transforms.insert( @@ -640,7 +644,7 @@ mod tests { #[test] fn test_transforms_with_kafka_meta() { - let mut test_value = json!({ + let test_value = json!({ "name": "A", "modified": "2021-03-16T14:38:58Z", }); @@ -655,6 +659,7 @@ mod tests { None, ); + let mut test_value: DeserializedMessage = test_value.into(); let mut transforms = HashMap::new(); transforms.insert("_kafka_offset".to_string(), "kafka.offset".to_string()); diff --git a/src/value_buffers.rs b/src/value_buffers.rs index 6e61725..fe6ed7f 100644 --- a/src/value_buffers.rs +++ b/src/value_buffers.rs @@ -1,21 +1,20 @@ use crate::{DataTypeOffset, DataTypePartition, IngestError}; -use serde_json::Value; use std::collections::HashMap; /// Provides a single interface into the multiple [`ValueBuffer`] instances used to buffer data for each assigned partition. #[derive(Debug, Default)] -pub(crate) struct ValueBuffers { - buffers: HashMap, +pub(crate) struct ValueBuffers { + buffers: HashMap>, len: usize, } -impl ValueBuffers { +impl ValueBuffers { /// Adds a value to in-memory buffers and tracks the partition and offset. pub(crate) fn add( &mut self, partition: DataTypePartition, offset: DataTypeOffset, - value: Value, + value: T, ) -> Result<(), IngestError> { let buffer = self .buffers @@ -40,7 +39,7 @@ impl ValueBuffers { } /// Returns values, partition offsets and partition counts currently held in buffer and resets buffers to empty. - pub(crate) fn consume(&mut self) -> ConsumedBuffers { + pub(crate) fn consume(&mut self) -> ConsumedBuffers { let mut partition_offsets = HashMap::new(); let mut partition_counts = HashMap::new(); @@ -76,14 +75,14 @@ impl ValueBuffers { /// Buffer of values held in memory for a single Kafka partition. #[derive(Debug)] -struct ValueBuffer { +struct ValueBuffer { /// The offset of the last message stored in the buffer. last_offset: DataTypeOffset, - /// The buffer of [`Value`] instances. - values: Vec, + /// The buffer of `T` instances. + values: Vec, } -impl ValueBuffer { +impl ValueBuffer { /// Creates a new [`ValueBuffer`] to store messages from a Kafka partition. pub(crate) fn new() -> Self { Self { @@ -97,13 +96,13 @@ impl ValueBuffer { } /// Adds the value to buffer and stores its offset as the `last_offset` of the buffer. - pub(crate) fn add(&mut self, value: Value, offset: DataTypeOffset) { + pub(crate) fn add(&mut self, value: T, offset: DataTypeOffset) { self.last_offset = offset; self.values.push(value); } /// Consumes and returns the buffer and last offset so it may be written to delta and clears internal state. - pub(crate) fn consume(&mut self) -> Option<(Vec, DataTypeOffset)> { + pub(crate) fn consume(&mut self) -> Option<(Vec, DataTypeOffset)> { if !self.values.is_empty() { assert!(self.last_offset >= 0); Some((std::mem::take(&mut self.values), self.last_offset)) @@ -114,9 +113,9 @@ impl ValueBuffer { } /// A struct that wraps the data consumed from [`ValueBuffers`] before writing to a [`arrow::record_batch::RecordBatch`]. -pub(crate) struct ConsumedBuffers { - /// The vector of [`Value`] instances consumed. - pub(crate) values: Vec, +pub(crate) struct ConsumedBuffers { + /// The vector of `T` instances consumed. + pub(crate) values: Vec, /// A [`HashMap`] from partition to last offset represented by the consumed buffers. pub(crate) partition_offsets: HashMap, /// A [`HashMap`] from partition to number of messages consumed for each partition. @@ -133,7 +132,7 @@ mod tests { let mut buffers = ValueBuffers::default(); let mut add = |p, o| { buffers - .add(p, o, Value::String(format!("{}:{}", p, o))) + .add(p, o, serde_json::Value::String(format!("{}:{}", p, o))) .unwrap(); }; @@ -188,7 +187,7 @@ mod tests { #[test] fn value_buffers_conflict_offsets_test() { - let mut buffers = ValueBuffers::default(); + let mut buffers: ValueBuffers = ValueBuffers::default(); let verify_error = |res: Result<(), IngestError>, o: i64| { match res.err().unwrap() { @@ -234,7 +233,7 @@ mod tests { ); } - fn add(buffers: &mut ValueBuffers, offset: i64) -> Result<(), IngestError> { - buffers.add(0, offset, Value::Number(offset.into())) + fn add(buffers: &mut ValueBuffers, offset: i64) -> Result<(), IngestError> { + buffers.add(0, offset, serde_json::Value::Number(offset.into())) } } diff --git a/src/writer.rs b/src/writer.rs index 274e37f..94d5659 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -40,6 +40,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; use crate::cursor::InMemoryWriteableCursor; +use crate::serialization::DeserializedMessage; const NULL_PARTITION_VALUE_DATA_PATH: &str = "__HIVE_DEFAULT_PARTITION__"; @@ -388,10 +389,15 @@ impl DataWriter { } /// Writes the given values to internal parquet buffers for each represented partition. - pub async fn write(&mut self, values: Vec) -> Result<(), Box> { + pub async fn write( + &mut self, + values: Vec, + ) -> Result<(), Box> { let mut partial_writes: Vec<(Value, ParquetError)> = Vec::new(); let arrow_schema = self.arrow_schema(); + let values = values.into_iter().map(|v| v.message()).collect(); + for (key, values) in self.divide_by_partition_values(values)? { match self.arrow_writers.get_mut(&key) { Some(writer) => collect_partial_write_failure( @@ -580,7 +586,7 @@ impl DataWriter { pub async fn insert_all( &mut self, table: &mut DeltaTable, - values: Vec, + values: Vec, ) -> Result> { self.write(values).await?; let mut adds = self.write_parquet_files(&table.table_uri()).await?; @@ -1184,10 +1190,11 @@ mod tests { let (table, _schema) = get_fresh_table(&temp_dir.path()).await; let mut writer = DataWriter::for_table(&table, HashMap::new()).unwrap(); - let rows: Vec = vec![json!({ + let rows: Vec = vec![json!({ "id" : "alpha", "value" : 1, - })]; + }) + .into()]; let result = writer.write(rows).await; assert!( result.is_ok(), @@ -1195,10 +1202,11 @@ mod tests { result ); - let rows: Vec = vec![json!({ + let rows: Vec = vec![json!({ "id" : 1, "value" : 1, - })]; + }) + .into()]; let result = writer.write(rows).await; assert!( result.is_err(), @@ -1208,15 +1216,17 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_schema_strictness_with_additional_columns() { let temp_dir = tempfile::tempdir().unwrap(); let (mut table, _schema) = get_fresh_table(&temp_dir.path()).await; let mut writer = DataWriter::for_table(&table, HashMap::new()).unwrap(); - let rows: Vec = vec![json!({ + let rows: Vec = vec![json!({ "id" : "alpha", "value" : 1, - })]; + }) + .into()]; let result = writer.write(rows).await; assert!( result.is_ok(), @@ -1224,11 +1234,12 @@ mod tests { result ); - let rows: Vec = vec![json!({ + let rows: Vec = vec![json!({ "id" : "bravo", "value" : 2, "color" : "silver", - })]; + }) + .into()]; let result = writer.write(rows).await; assert!( result.is_ok(), @@ -1258,7 +1269,7 @@ mod tests { .await .unwrap(); let mut writer = DataWriter::for_table(&table, HashMap::new()).unwrap(); - let rows: Vec = vec![json!({ + let rows: Vec = vec![json!({ "meta": { "kafka": { "offset": 0, @@ -1275,7 +1286,8 @@ mod tests { // an error that gets interpreted as an EmptyRecordBatch "some_nested_list": [[42], [84]], "date": "2021-06-22" - })]; + }) + .into()]; let result = writer.write(rows).await; assert!( result.is_err(), @@ -1306,7 +1318,10 @@ mod tests { .unwrap(); let mut writer = DataWriter::for_table(&table, HashMap::new()).unwrap(); - writer.write(JSON_ROWS.clone()).await.unwrap(); + writer + .write(JSON_ROWS.clone().into_iter().map(|r| r.into()).collect()) + .await + .unwrap(); let add = writer .write_parquet_files(&table.table_uri()) .await diff --git a/tests/delta_partitions_tests.rs b/tests/delta_partitions_tests.rs index 0e872a2..b96ef16 100644 --- a/tests/delta_partitions_tests.rs +++ b/tests/delta_partitions_tests.rs @@ -3,6 +3,7 @@ mod helpers; use deltalake_core::kernel::{Action, Add}; use deltalake_core::protocol::{DeltaOperation, SaveMode}; +use kafka_delta_ingest::serialization::DeserializedMessage; use kafka_delta_ingest::writer::*; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -122,10 +123,10 @@ async fn test_delta_partitions() { std::fs::remove_dir_all(&table_path).unwrap(); } -fn msgs_to_values(values: Vec) -> Vec { +fn msgs_to_values(values: Vec) -> Vec { values .iter() - .map(|j| serde_json::to_value(j).unwrap()) + .map(|j| serde_json::to_value(j).unwrap().into()) .collect() }