ready for review
sfc-gh-kschmaus committed Jan 16, 2025
1 parent 207206b commit f86f5b2
Showing 8 changed files with 127 additions and 123 deletions.
5 changes: 2 additions & 3 deletions app_utils/shared_utils.py
@@ -1327,7 +1327,7 @@ def run_generate_model_str_from_snowflake(
model_name: str,
sample_values: int,
base_tables: list[str],
allow_joins: Optional[bool] = False,
allow_joins: bool = False,
) -> None:
"""
Runs generate_model_str_from_snowflake to generate a Cortex semantic model shell.
@@ -1357,15 +1357,14 @@ def run_generate_model_str_from_snowflake(
st.session_state["yaml"] = yaml_str


# TODO(kschmaus): I still need to use this!
def create_cortex_search_service(
conn: SnowflakeConnection,
service_name: str,
column_name: str,
table_fqn: str,
warehouse_name: str,
target_lag: str,
):
) -> None:
query = f"""
CREATE OR REPLACE CORTEX SEARCH SERVICE {service_name}
ON {column_name}
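
The body of `create_cortex_search_service` is truncated in this diff. For orientation only, a minimal sketch of such a helper following Snowflake's documented `CREATE CORTEX SEARCH SERVICE` syntax — the repo's actual query text and options may differ:

```python
# Hedged sketch, not the repo's implementation: issues a
# CREATE CORTEX SEARCH SERVICE statement per Snowflake's documented syntax.
from snowflake.connector import SnowflakeConnection


def create_cortex_search_service_sketch(
    conn: SnowflakeConnection,
    service_name: str,
    column_name: str,
    table_fqn: str,
    warehouse_name: str,
    target_lag: str,
) -> None:
    query = f"""
    CREATE OR REPLACE CORTEX SEARCH SERVICE {service_name}
      ON {column_name}
      WAREHOUSE = {warehouse_name}
      TARGET_LAG = '{target_lag}'
      AS (SELECT DISTINCT {column_name} FROM {table_fqn})
    """
    conn.cursor().execute(query)
```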
92 changes: 42 additions & 50 deletions journeys/builder.py
@@ -3,9 +3,13 @@
import streamlit as st
from loguru import logger

from app_utils.shared_utils import create_cortex_search_service
from semantic_model_generator.data_processing import proto_utils
from semantic_model_generator.generate_model import append_comment_to_placeholders
from semantic_model_generator.generate_model import comment_out_section
from semantic_model_generator.generate_model import context_to_yaml
from semantic_model_generator.generate_model import get_table_representations
from semantic_model_generator.generate_model import translate_data_class_tables_to_model_protobuf
from semantic_model_generator.protos import semantic_model_pb2
from snowflake.connector import ProgrammingError
from streamlit_extras.tags import tagger_component
@@ -176,25 +180,14 @@ def generate_table_configs() -> None:
with st.spinner(
"Writing Semantic Model Descriptions Using AI (this may take a moment) ..."
):
tables = []
conn = get_snowflake_connection()
for table_fqn in st.session_state.get("selected_tables", []):
columns_df = get_valid_schemas_tables_columns_df(
conn=conn,
table_fqn=table_fqn,
)
table = get_table_representation(
session=st.session_state["session"],
table_fqn=table_fqn,
max_string_sample_values=16,
columns_df=columns_df,
max_workers=1,
)
tables.append(table)
base_tables = st.session_state["selected_tables"]
st.session_state["tables"] = get_table_representations(
conn=conn, base_tables=base_tables
)

# We will need warehouses for search integration.
st.session_state["available_warehouses"] = get_available_warehouses()
st.session_state["tables"] = tables


@st.experimental_fragment()
@@ -224,6 +217,7 @@ def _add_search_form(column: str) -> None:
options=get_available_warehouses(),
key=f"cortex_search_warehouse_name_{column}",
)
assert warehouse_name is not None
target_lag = st.text_input(
label="Target Lag",
value="1 hour",
@@ -264,6 +258,7 @@ def _remove_search_form(column: str) -> None:
options=possible_columns,
key="cortex_search_possible_columns",
)
assert current_column is not None
if current_column not in cortex_search_configs:
_add_search_form(current_column)
else:
@@ -292,42 +287,32 @@ def _remove_search_form(column: str) -> None:
literal_column=config.column_name,
)

# TODO(kschmaus) create cortex search instances, use spinner ...

st.rerun()
# TODO(kschmaus): this should probably be its own function ...
def create_cortex_search_services() -> None:
conn = get_snowflake_connection()
cortex_search_configs = st.session_state["cortex_search_configs"]
with st.spinner("Creating Cortex Search Services ..."):
config: CortexSearchConfig
for config in cortex_search_configs.values():
create_cortex_search_service(
conn=conn,
service_name=config.service_name,
column_name=config.column_name,
table_fqn=config.table_fqn,
warehouse_name=config.warehouse_name,
target_lag=config.target_lag
)

# TODO(kschmaus): at a certain point this shouldn't live in builder ...
placeholder_relationships = (
_get_placeholder_joins()
if st.session_state["enable_joins"]
else None
)
table_objects = [
_raw_table_to_semantic_context_table(raw_table=table) for table in tables
]
context = semantic_model_pb2.SemanticModel(
name=st.session_state["semantic_model_name"],
tables=table_objects,
relationships=placeholder_relationships,
)
# Validate the generated yaml is within context limits.
# We just throw a warning here to allow users to update.
validate_context_length(context)

yaml_str = proto_utils.proto_to_yaml(context)
# Once we have the yaml, update to include to # <FILL-OUT> tokens.
yaml_str = append_comment_to_placeholders(yaml_str)
# Comment out the filters section as we don't have a way to auto-generate these yet.
yaml_str = comment_out_section(yaml_str, "filters")
yaml_str = comment_out_section(yaml_str, "relationships")
st.session_state["yaml"] = yaml_str

st.code(st.session_state["yaml"])

# TODO(kschmaus): this ends things
st.session_state["page"] = GeneratorAppScreen.ITERATION
st.rerun()

def save_yaml() -> None:
tables = st.session_state["tables"]
context = translate_data_class_tables_to_model_protobuf(
raw_tables=tables,
semantic_model_name=st.session_state["semantic_model_name"],
allow_joins=st.session_state["enable_joins"],
)
yaml_string = context_to_yaml(context)
st.session_state["yaml"] = yaml_string


def show() -> None:
@@ -351,4 +336,11 @@ def show() -> None:
if not st.session_state["build_semantic_model"]:
st.stop()

# TODO(kschmaus): final function?
# (Re)Create cortex search services.
create_cortex_search_services()

# Save YAML model to session state.
save_yaml()

st.session_state["page"] = GeneratorAppScreen.ITERATION
st.rerun()
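
`create_cortex_search_services` reads `CortexSearchConfig` objects out of session state. The class definition is not part of this diff; judging from the attributes accessed above, it is presumably shaped roughly as follows (an inferred sketch, not the repo's definition):

```python
# Inferred from the attribute accesses in create_cortex_search_services();
# the actual CortexSearchConfig may carry additional fields.
from dataclasses import dataclass


@dataclass
class CortexSearchConfig:
    service_name: str
    column_name: str
    table_fqn: str
    warehouse_name: str
    target_lag: str
```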
1 change: 0 additions & 1 deletion journeys/evaluation.py
@@ -132,7 +132,6 @@ def _llm_judge(frame: pd.DataFrame) -> pd.DataFrame:
if frame.empty:
return pd.DataFrame({"EXPLANATION": [], "CORRECT": []})

# TODO(kschmaus): abstract batch complete call <----
prompt_frame = frame.apply(
axis=1,
func=lambda x: LLM_JUDGE_PROMPT_TEMPLATE.format(
29 changes: 15 additions & 14 deletions journeys/iteration.py
@@ -2,6 +2,7 @@
import time
from typing import Any, Dict, List, Optional

import loguru
import pandas as pd
import sqlglot
import streamlit as st
@@ -491,20 +492,20 @@ def yaml_editor(yaml_str: str) -> None:

def validate_and_update_session_state() -> None:
# Validate new content
# try:
validate(
content,
conn=get_snowflake_connection(),
)
st.session_state["validated"] = True
update_container(status_container, "success", prefix=status_container_title)
st.session_state.semantic_model = yaml_to_semantic_model(content)
st.session_state.last_saved_yaml = content
# except Exception as e:
# loguru.logger.error(e)
# st.session_state["validated"] = False
# update_container(status_container, "failed", prefix=status_container_title)
# exception_as_dialog(e)
try:
validate(
content,
conn=get_snowflake_connection(),
)
st.session_state["validated"] = True
update_container(status_container, "success", prefix=status_container_title)
st.session_state.semantic_model = yaml_to_semantic_model(content)
st.session_state.last_saved_yaml = content
except Exception as e:
loguru.logger.error(e)
st.session_state["validated"] = False
update_container(status_container, "failed", prefix=status_container_title)
exception_as_dialog(e)

button_row = row(5)
if button_row.button("Validate", use_container_width=True, help=VALIDATE_HELP):
1 change: 0 additions & 1 deletion semantic_model_generator/data_processing/data_types.py
@@ -30,7 +30,6 @@ class Column:
values: Optional[List[str]] = None
# comment field to save the column comment the user specified on the column
comment: Optional[str] = None
# TODO(kschmaus): this probably doesn't belong here.
cortex_search_service: Optional[CortexSearchService] = None

def __post_init__(self: Any) -> None:
2 changes: 1 addition & 1 deletion semantic_model_generator/data_processing/proto_utils.py
@@ -30,7 +30,7 @@ def proto_to_yaml(message: ProtoMsg) -> str:
# Using ruamel.yaml package to preserve message order.
yaml = ruamel.yaml.YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.preserve_quotes = True # type: ignore[assignment]
yaml.preserve_quotes = True

with io.StringIO() as stream:
yaml.dump(json_data, stream)
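
Dropping the `# type: ignore[assignment]` is safe in current ruamel.yaml releases, where `preserve_quotes` is a plain attribute. A standalone round-trip demonstration of the same settings:

```python
# Self-contained demo of the settings proto_to_yaml uses: round-trip mode,
# 2/4/2 indentation, and quote preservation.
import io

import ruamel.yaml

yaml = ruamel.yaml.YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.preserve_quotes = True

data = yaml.load('name: "quoted value"\nitems:\n  - a\n  - b\n')
with io.StringIO() as stream:
    yaml.dump(data, stream)
    print(stream.getvalue())  # the quotes around "quoted value" survive
```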
117 changes: 65 additions & 52 deletions semantic_model_generator/generate_model.py
@@ -7,6 +7,7 @@
from snowflake.snowpark import Session

from semantic_model_generator.data_processing import data_types, proto_utils
from semantic_model_generator.data_processing.data_types import Table
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.snowflake_utils.snowflake_connector import (
AUTOGEN_TOKEN,
@@ -151,18 +152,6 @@ def _raw_table_to_semantic_context_table(
sample_values=col.values,
)
)
# elif col.column_type.upper() in DIMENSION_DATATYPES:
# print("HELLO!!!!!!", col.column_type.upper())
# dimensions.append(
# semantic_model_pb2.Dimension(
# name=col.column_name,
# expr=col.column_name,
# data_type=col.column_type,
# sample_values=col.values,
# synonyms=[_PLACEHOLDER_COMMENT],
# description=col.comment if col.comment else _PLACEHOLDER_COMMENT,
# )
# )
if len(time_dimensions) + len(dimensions) + len(measures) == 0:
raise ValueError(
f"No valid columns found for table {raw_table.name}. Please verify that this table contains columns with datatypes not in {OBJECT_DATATYPES}."
Expand All @@ -183,11 +172,49 @@ def _raw_table_to_semantic_context_table(
)


def get_table_representations(conn: SnowflakeConnection, base_tables: List[str]) -> List[Table]:
raw_tables = []
for table_fqn in base_tables:
columns_df = get_valid_schemas_tables_columns_df(conn=conn, table_fqn=table_fqn)
assert not columns_df.empty

session = Session.builder.configs({"connection": conn}).create()
raw_table = get_table_representation(
session=session,
table_fqn=table_fqn,
max_string_sample_values=16,
columns_df=columns_df,
max_workers=1,
)
raw_tables.append(raw_table)
return raw_tables


def translate_data_class_tables_to_model_protobuf(
raw_tables: List[Table],
semantic_model_name: str,
allow_joins: bool = False,
) -> semantic_model_pb2.SemanticModel:
table_objects = []
for raw_table in raw_tables:
table_object = _raw_table_to_semantic_context_table(raw_table=raw_table)
table_objects.append(table_object)
# TODO(jhilgart): Call cortex model to generate a semantically friendly name here.

placeholder_relationships = _get_placeholder_joins() if allow_joins else None
context = semantic_model_pb2.SemanticModel(
name=semantic_model_name,
tables=table_objects,
relationships=placeholder_relationships,
)
return context


def raw_schema_to_semantic_context(
base_tables: List[str],
semantic_model_name: str,
conn: SnowflakeConnection,
allow_joins: Optional[bool] = False,
allow_joins: bool = False,
) -> semantic_model_pb2.SemanticModel:
"""
Converts a list of fully qualified Snowflake table names into a semantic model.
@@ -210,32 +237,14 @@
"""

# For FQN tables, create a new snowflake connection per table in case the db/schema is different.
table_objects = []
raw_tables = []
for table_fqn in base_tables:
columns_df = get_valid_schemas_tables_columns_df(conn=conn, table_fqn=table_fqn)
assert not columns_df.empty

# TODO(kschmaus): clean this up, it's pretty awkward.
session = Session.builder.configs({"connection": conn}).create()
raw_table = get_table_representation(
session=session,
table_fqn=table_fqn,
max_string_sample_values=16,
columns_df=columns_df,
max_workers=1,
)
raw_tables.append(raw_table)

table_object = _raw_table_to_semantic_context_table(raw_table=raw_table)
table_objects.append(table_object)
# TODO(jhilgart): Call cortex model to generate a semantically friendly name here.

placeholder_relationships = _get_placeholder_joins() if allow_joins else None
context = semantic_model_pb2.SemanticModel(
name=semantic_model_name,
tables=table_objects,
relationships=placeholder_relationships,
raw_tables = get_table_representations(
conn=conn,
base_tables=base_tables,
)
context = translate_data_class_tables_to_model_protobuf(
raw_tables=raw_tables,
semantic_model_name=semantic_model_name,
allow_joins=allow_joins
)
return context

Expand Down Expand Up @@ -394,12 +403,26 @@ def generate_base_semantic_model_from_snowflake(
return None


# TODO(kschmaus): repair?
def context_to_yaml(context: semantic_model_pb2.SemanticModel) -> str:
# Validate the generated yaml is within context limits.
# We just throw a warning here to allow users to update.
validate_context_length(context)

yaml_str = proto_utils.proto_to_yaml(context)
# Once we have the yaml, update to include the # <FILL-OUT> tokens.
yaml_str = append_comment_to_placeholders(yaml_str)
# Comment out the filters section as we don't have a way to auto-generate these yet.
yaml_str = comment_out_section(yaml_str, "filters")
yaml_str = comment_out_section(yaml_str, "relationships")

return yaml_str


def generate_model_str_from_snowflake(
base_tables: List[str],
semantic_model_name: str,
conn: SnowflakeConnection,
allow_joins: Optional[bool] = False,
allow_joins: bool = False,
) -> str:
"""
Generates a base semantic context from specified Snowflake tables and returns the raw string.
@@ -419,15 +442,5 @@
allow_joins=allow_joins,
conn=conn,
)
# Validate the generated yaml is within context limits.
# We just throw a warning here to allow users to update.
validate_context_length(context)

yaml_str = proto_utils.proto_to_yaml(context)
# Once we have the yaml, update to include to # <FILL-OUT> tokens.
yaml_str = append_comment_to_placeholders(yaml_str)
# Comment out the filters section as we don't have a way to auto-generate these yet.
yaml_str = comment_out_section(yaml_str, "filters")
yaml_str = comment_out_section(yaml_str, "relationships")

yaml_str = context_to_yaml(context)
return yaml_str
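
With this commit the generation path decomposes into three reusable steps — `get_table_representations` fetches raw table metadata, `translate_data_class_tables_to_model_protobuf` builds the proto, and `context_to_yaml` validates and serializes — leaving `generate_model_str_from_snowflake` as a thin wrapper. A usage sketch (the connection parameters and table name below are placeholders, not from the repo):

```python
# Usage sketch with placeholder credentials and an illustrative table name.
from snowflake.connector import connect

from semantic_model_generator.generate_model import generate_model_str_from_snowflake

conn = connect(
    account="my_account",    # placeholder
    user="my_user",          # placeholder
    password="my_password",  # placeholder
)
yaml_str = generate_model_str_from_snowflake(
    base_tables=["MY_DB.MY_SCHEMA.ORDERS"],
    semantic_model_name="orders_semantic_model",
    conn=conn,
    allow_joins=False,
)
print(yaml_str)
```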