Adding Run Name to Customer Eval Results Table #222

Merged (3 commits) on Jan 17, 2025
journeys/evaluation.py: 128 changes (75 additions, 53 deletions)
@@ -51,6 +51,7 @@
     "SEMANTIC_MODEL_STRING": "VARCHAR",
     "EVAL_TABLE": "VARCHAR",
     "EVAL_HASH": "VARCHAR",
+    "EVAL_RUN_NAME": "VARCHAR",
 }
 
 LLM_JUDGE_PROMPT_TEMPLATE = """\
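
Note: the new `EVAL_RUN_NAME` column is only declared for results tables created after this change. A results table created before it would need a one-time migration; a minimal sketch, assuming a Snowflake connection object and a placeholder table name (neither is part of this PR):

```python
def add_run_name_column(conn, results_table: str) -> None:
    """One-time migration sketch; table name and connection are assumptions."""
    # VARCHAR matches the type declared for EVAL_RUN_NAME in the schema dict above
    conn.cursor().execute(
        f"ALTER TABLE {results_table} ADD COLUMN IF NOT EXISTS EVAL_RUN_NAME VARCHAR"
    )
```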
@@ -79,56 +80,58 @@ def visualize_eval_results(frame: pd.DataFrame) -> None:
     n_questions = len(frame)
     n_correct = frame["CORRECT"].sum()
     accuracy = (n_correct / n_questions) * 100
-    st.markdown(
-        f"###### Results: {n_correct} out of {n_questions} questions correct with accuracy {accuracy:.2f}%"
-    )
-    for row_id, row in frame.iterrows():
-        match_emoji = "✅" if row["CORRECT"] else "❌"
-        with st.expander(f"Row ID: {row_id} {match_emoji}"):
-            st.write(f"Input Query: {row['QUERY']}")
-            st.write(row["ANALYST_TEXT"].replace("\n", " "))
-
-            col1, col2 = st.columns(2)
-
-            try:
-                analyst_sql = sqlglot.parse_one(row["ANALYST_SQL"], dialect="snowflake")
-                analyst_sql = analyst_sql.sql(dialect="snowflake", pretty=True)
-            except Exception as e:
-                logger.warning(f"Error parsing analyst SQL: {e} for {row_id}")
-                analyst_sql = row["ANALYST_SQL"]
-
-            try:
-                gold_sql = sqlglot.parse_one(row["GOLD_SQL"], dialect="snowflake")
-                gold_sql = gold_sql.sql(dialect="snowflake", pretty=True)
-            except Exception as e:
-                logger.warning(f"Error parsing gold SQL: {e} for {row_id}")
-                gold_sql = row["GOLD_SQL"]
-
-            with col1:
-                st.write("Analyst SQL")
-                st.code(analyst_sql, language="sql")
-
-            with col2:
-                st.write("Golden SQL")
-                st.code(gold_sql, language="sql")
-
-            col1, col2 = st.columns(2)
-            with col1:
-                if isinstance(row["ANALYST_RESULT"], str):
-                    st.error(row["ANALYST_RESULT"])
-                else:
-                    st.write(row["ANALYST_RESULT"])
-
-            with col2:
-                if isinstance(row["GOLD_RESULT"], str):
-                    st.error(row["GOLD_RESULT"])
-                else:
-                    st.write(row["GOLD_RESULT"])
-
-            st.write(f"**Explanation**: {row['EXPLANATION']}")
-
-
-def _llm_judge(frame: pd.DataFrame) -> pd.DataFrame:
+    results_placeholder = st.session_state.get("eval_results_placeholder")
+    with results_placeholder.container():
+        st.markdown(
+            f"###### Results: {n_correct} out of {n_questions} questions correct with accuracy {accuracy:.2f}%"
+        )
+        for row_id, row in frame.iterrows():
+            match_emoji = "✅" if row["CORRECT"] else "❌"
+            with st.expander(f"Row ID: {row_id} {match_emoji}"):
+                st.write(f"Input Query: {row['QUERY']}")
+                st.write(row["ANALYST_TEXT"].replace("\n", " "))
+
+                col1, col2 = st.columns(2)
+
+                try:
+                    analyst_sql = sqlglot.parse_one(row["ANALYST_SQL"], dialect="snowflake")
+                    analyst_sql = analyst_sql.sql(dialect="snowflake", pretty=True)
+                except Exception as e:
+                    logger.warning(f"Error parsing analyst SQL: {e} for {row_id}")
+                    analyst_sql = row["ANALYST_SQL"]
+
+                try:
+                    gold_sql = sqlglot.parse_one(row["GOLD_SQL"], dialect="snowflake")
+                    gold_sql = gold_sql.sql(dialect="snowflake", pretty=True)
+                except Exception as e:
+                    logger.warning(f"Error parsing gold SQL: {e} for {row_id}")
+                    gold_sql = row["GOLD_SQL"]
+
+                with col1:
+                    st.write("Analyst SQL")
+                    st.code(analyst_sql, language="sql")
+
+                with col2:
+                    st.write("Golden SQL")
+                    st.code(gold_sql, language="sql")
+
+                col1, col2 = st.columns(2)
+                with col1:
+                    if isinstance(row["ANALYST_RESULT"], str):
+                        st.error(row["ANALYST_RESULT"])
+                    else:
+                        st.write(row["ANALYST_RESULT"])
+
+                with col2:
+                    if isinstance(row["GOLD_RESULT"], str):
+                        st.error(row["GOLD_RESULT"])
+                    else:
+                        st.write(row["GOLD_RESULT"])
+
+                st.write(f"**Explanation**: {row['EXPLANATION']}")
+
+
+def _llm_judge(frame: pd.DataFrame, max_frame_size=200) -> pd.DataFrame:
 
     if frame.empty:
         return pd.DataFrame({"EXPLANATION": [], "CORRECT": []})
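
The rewritten function now renders everything inside a placeholder created with `st.empty()` (set up in the `evaluation_mode_show` hunk below), so a later run can wipe stale results before drawing new ones. A minimal standalone sketch of that Streamlit pattern, with illustrative content only:

```python
import streamlit as st

placeholder = st.empty()  # reserves a single slot in the page

with placeholder.container():  # a container lets that one slot hold many elements
    st.markdown("###### Results")
    st.write("per-row output goes here")

# later (e.g. at the start of the next run), clear everything in the slot:
placeholder.empty()
```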
@@ -140,8 +143,8 @@ def _llm_judge(frame: pd.DataFrame) -> pd.DataFrame:
         axis=1,
         func=lambda x: LLM_JUDGE_PROMPT_TEMPLATE.format(
             input_question=x["QUERY"],
-            frame1_str=x["ANALYST_RESULT"].to_string(index=False),
-            frame2_str=x["GOLD_RESULT"].to_string(index=False),
+            frame1_str=x["ANALYST_RESULT"][:max_frame_size].to_string(index=False),
+            frame2_str=x["GOLD_RESULT"][:max_frame_size].to_string(index=False),
         ),
     ).to_frame(name=col_name)
     session = st.session_state["session"]
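
The new `max_frame_size` bound keeps large result sets from inflating the judge prompt: slicing a DataFrame with `frame[:n]` is positional row slicing, equivalent to `frame.head(n)`. For illustration:

```python
import pandas as pd

frame = pd.DataFrame({"REVENUE": range(1000)})
assert frame[:200].equals(frame.head(200))  # same 200 leading rows
prompt_str = frame[:200].to_string(index=False)  # bounded prompt payload
```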
@@ -289,6 +292,7 @@ def write_eval_results(frame: pd.DataFrame) -> None:
     frame_to_write = frame.copy()
     frame_to_write["TIMESTAMP"] = st.session_state["eval_timestamp"]
     frame_to_write["EVAL_HASH"] = st.session_state["eval_hash"]
+    frame_to_write["EVAL_RUN_NAME"] = st.session_state["eval_run_name"]
     frame_to_write["EVAL_TABLE"] = st.session_state["eval_table"]
     frame_to_write["EVAL_TABLE_HASH"] = st.session_state["eval_table_hash"]
     frame_to_write["MODEL_HASH"] = st.session_state["semantic_model_hash"]
@@ -353,7 +357,9 @@ def run_sql_queries() -> None:
 
     for i, (row_id, row) in enumerate(eval_table_frame.iterrows(), start=1):
         status_text.text(f"Evaluating Analyst query {i}/{total_requests}...")
+
+
 
         analyst_query = analyst_results_frame.loc[row_id, "ANALYST_SQL"]
         analyst_result = execute_query(
             conn=get_snowflake_connection(), query=analyst_query
@@ -586,6 +592,7 @@ def clear_evaluation_selection() -> None:
         "selected_results_eval_old_table",
         "selected_results_eval_schema",
         "use_existing_eval_results_table",
+        "selected_eval_run_name",
     )
     for feature in session_states:
         if feature in st.session_state:
@@ -598,6 +605,9 @@ def clear_evaluation_data() -> None:
         "eval_accuracy",
         "analyst_results_frame",
         "query_results_frame",
+        "eval_run_name",
+        "eval_timestamp",
+        "eval_hash",
     )
     for feature in session_states:
         if feature in st.session_state:
@@ -612,6 +622,7 @@ def evaluation_mode_show() -> None:
     st.write(
         "Welcome!🧪 In the evaluation mode you can evaluate your semantic model using pairs of golden queries/questions and their expected SQL statements. These pairs should be captured in an **Evaluation Table**. Accuracy metrics will be shown and the results will be stored in an **Evaluation Results Table**."
     )
+    st.text_input("Evaluation Run Name", key="selected_eval_run_name", value=st.session_state.get("selected_eval_run_name", ""))
 
     # TODO: find a less awkward way of specifying this.
     if any(key not in st.session_state for key in ("eval_table", "results_eval_table")):
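
How the name flows through session state in this PR: the widget owns `selected_eval_run_name`, `run_evaluation` (last hunk below) snapshots it into `eval_run_name`, and `write_eval_results` persists that value with each row. A condensed sketch of the two ends of that flow:

```python
import streamlit as st

st.text_input("Evaluation Run Name", key="selected_eval_run_name")

if st.button("Run Evaluation"):
    # snapshot the widget value so the rest of the run (and the results table)
    # sees a stable eval_run_name even as the page reruns
    st.session_state["eval_run_name"] = st.session_state["selected_eval_run_name"]
```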
@@ -631,12 +642,15 @@
     if st.button("Run Evaluation"):
         run_evaluation()
 
+
+
     if "total_eval_frame" in st.session_state:
         current_hash = generate_hash(st.session_state["working_yml"])
         model_changed_test = current_hash != st.session_state["semantic_model_hash"]
 
         evolution_run_summary = pd.DataFrame(
             [
+                ["Evaluation Run Name", st.session_state["eval_run_name"]],
                 ["Evaluation Table Hash", st.session_state["eval_table_hash"]],
                 ["Semantic Model Hash", st.session_state["semantic_model_hash"]],
                 ["Evaluation Run Hash", st.session_state["eval_hash"]],
@@ -651,6 +665,7 @@ def evaluation_mode_show() -> None:
         else:
             st.markdown("#### Current Evaluation Run Summary")
         st.dataframe(evolution_run_summary, hide_index=True)
+        st.session_state["eval_results_placeholder"] = st.empty()
         visualize_eval_results(st.session_state["total_eval_frame"])
 
 
@@ -684,8 +699,15 @@ def run_evaluation() -> None:
         return
     placeholder.write("Model validated ✅")
     clear_evaluation_data()
+    st.session_state["eval_run_name"] = st.session_state["selected_eval_run_name"]
     st.session_state["semantic_model_hash"] = current_hash
-    st.write("Running evaluation...")
+    if st.session_state["eval_run_name"] == "":
+        st.write("Running evaluation ...")
+    else:
+        st.write(f"Running evaluation for name: {st.session_state['eval_run_name']} ...")
+    if "eval_results_placeholder" in st.session_state:
+        results_placeholder = st.session_state["eval_results_placeholder"]
+        results_placeholder.empty()
     st.session_state["eval_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
     st.session_state["eval_hash"] = generate_hash(st.session_state["eval_timestamp"])
     send_analyst_requests()
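
For reference, the run identifiers set in this hunk: `eval_hash` is derived from the run timestamp. `generate_hash` is internal to this repo, so the sketch below substitutes `hashlib` purely for illustration:

```python
import hashlib
import time

eval_timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
# stand-in for the repo's generate_hash(); the real implementation may differ
eval_hash = hashlib.sha256(eval_timestamp.encode()).hexdigest()
```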