fix: Use cassandra connector #164

Open · wants to merge 21 commits into master
107 changes: 104 additions & 3 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,6 +1,17 @@
import time
from datetime import datetime
from types import MethodType
from typing import List, Optional, Set, Union, no_type_check
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
no_type_check,
)

import pandas as pd
import pyarrow
@@ -10,8 +21,15 @@
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.functions import col, from_json
from pyspark.sql.streaming import StreamingQuery
from pyspark.sql.types import (
BinaryType,
StringType,
StructField,
StructType,
TimestampType,
)

from feast import FeatureView
from feast import FeatureView, RepoConfig
from feast.data_format import AvroFormat, ConfluentAvroFormat, JsonFormat, StreamFormat
from feast.data_source import KafkaSource, PushMode
from feast.feature_store import FeatureStore
@@ -20,10 +38,16 @@
StreamProcessor,
StreamTable,
)
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
_SparkSerializedArtifacts,
)
from feast.infra.online_stores.contrib.cassandra_online_store.cassandra_online_store import (
CassandraOnlineStore,
)
from feast.infra.provider import get_provider
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.stream_feature_view import StreamFeatureView
from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping

@@ -272,6 +296,83 @@ def _write_stream_data_expedia(self, df: StreamTable, to: PushMode):
# TODO: Support writing to offline store and preprocess_fn. Remove _write_stream_data method

# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def online_write_with_connector(
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[
EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]
]
],
progress: Optional[Callable[[int], Any]],
) -> None:
"""
Write a batch of features for several entities to the database using the Spark Cassandra Connector.

Args:
config: The RepoConfig for the current FeatureStore.
table: Feast FeatureView.
data: a list of quadruplets containing feature data. Each
quadruplet contains an entity key, a dict of feature
values, the event timestamp for the row, and the
created timestamp for the row if it exists.
progress: Optional function to be called once every mini-batch of
rows is written to the online store. Can be used to
display progress.
"""
cassandra_keyspace = config.online_store.keyspace
cassandra_table = CassandraOnlineStore._fq_table_name(
cassandra_keyspace, config.project, table
)

def create_spark_dataframe():
"""
Convert the data into a Spark DataFrame.
"""
rows = []
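# Flatten the quadruplets into one row per (entity, feature) pair,
# hex-encoding the serialized entity key.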
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
rows.append(
(
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
created_ts,
)
)

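# Column layout for the DataFrame; the connector matches these names
# against the target Cassandra table's columns.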
schema = StructType(
[
StructField("feature_name", StringType(), False),
StructField("feature_value", BinaryType(), False),
StructField("entity_key", StringType(), False),
StructField("event_timestamp", TimestampType(), False),
StructField("created_timestamp", TimestampType(), True),
]
)

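# Builds the DataFrame on the SparkSession captured from the enclosing
# processor (self.spark); online_write_with_connector itself takes no self.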
return self.spark.createDataFrame(rows, schema)

# Create a DataFrame from the input data
df = create_spark_dataframe()

# Write DataFrame to Cassandra
df.write.format("org.apache.spark.sql.cassandra").options(
keyspace=cassandra_keyspace, table=cassandra_table
).mode("append").save()

# Call progress function if provided
if progress:
progress(len(data))

def batch_write_pandas_df(iterator, spark_serialized_artifacts, join_keys):
for pdf in iterator:
(
@@ -305,7 +406,7 @@ def batch_write_pandas_df(iterator, spark_serialized_artifacts, join_keys):
rows_to_write = _convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
)
online_store.online_write_batch(
online_write_with_connector(
repo_config,
feature_view,
rows_to_write,
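For reference, the write path introduced above is the standard Spark Cassandra Connector DataFrame API. Below is a minimal, self-contained sketch of that pattern, assuming a reachable Cassandra node and a pre-created table; the connection settings, connector version, and keyspace/table names are illustrative placeholders, not values from this PR (there they come from RepoConfig and CassandraOnlineStore._fq_table_name).

import datetime

from pyspark.sql import SparkSession
from pyspark.sql.types import (
    BinaryType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

# The connector must be on the classpath, e.g. via
# spark-submit --packages com.datastax.spark:spark-cassandra-connector_2.12:3.4.1
spark = (
    SparkSession.builder.appName("cassandra-connector-sketch")
    .config("spark.cassandra.connection.host", "127.0.0.1")  # placeholder host
    .config("spark.cassandra.connection.port", "9042")  # default CQL port
    .getOrCreate()
)

# Same five-column layout the PR builds: one row per (entity, feature) pair.
schema = StructType(
    [
        StructField("feature_name", StringType(), False),
        StructField("feature_value", BinaryType(), False),
        StructField("entity_key", StringType(), False),
        StructField("event_timestamp", TimestampType(), False),
        StructField("created_timestamp", TimestampType(), True),
    ]
)

rows = [
    # feature_value carries a serialized ValueProto; raw bytes stand in here.
    ("driver_rating", b"\x08\x05", "a1b2c3", datetime.datetime.utcnow(), None),
]
df = spark.createDataFrame(rows, schema)

# The connector matches DataFrame columns to table columns by name, so the
# target table must already define these five columns.
(
    df.write.format("org.apache.spark.sql.cassandra")
    .options(keyspace="feast_keyspace", table="feast_project_feature_view")  # placeholders
    .mode("append")
    .save()
)

Unlike this sketch, the PR does not create its own session: create_spark_dataframe builds the DataFrame on self.spark captured from the enclosing processor. That is worth keeping in mind during review, since functions like batch_write_pandas_df that take an iterator of pandas DataFrames are typically executed on workers via mapInPandas, where no SparkSession is available.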