Skip to content

Commit

Permalink
Ita grounding analysis (#64)
Browse files Browse the repository at this point in the history
* started setting it up, needs some ironing out but almost good to go

* finished the creation of the table, forgot I was in this branch and also fixed #63 here
  • Loading branch information
leokim-l authored Nov 29, 2024
1 parent 6586222 commit fe51a68
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
79 changes: 79 additions & 0 deletions src/malco/analysis/ita_grounding_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Build a curation table comparing Italian vs. English GPT replies.

Each output row holds:
PMID * GPT diagnosis label * MONDO ID * rank * ita_reply (bool)
* correct OMIM ID * correct OMIM label
(the final correct? 0/1 column is scored later in Excel).
"""
from malco.post_process.post_process_results_format import read_raw_result_yaml
from pathlib import Path
import pandas as pd
import os

# Ground-truth answers: one row per case, giving the correct OMIM term and label.
# NOTE(review): hard-coded local path — parameterize before reuse.
file = "/Users/leonardo/git/malco/in_ita_reply/correct_results.tsv"
answers = pd.read_csv(
    file, sep="\t", header=None, names=["description", "term", "label"]
)

# Map each case label to its correct term/description:
# cres['term'][label] -> OMIM ID, cres['description'][label] -> OMIM label.
cres = answers.set_index("label").to_dict()

# Rows are accumulated here and sorted alphabetically at the end.
data = []

# --- Italian replies -------------------------------------------------------
ita_file = Path("/Users/leonardo/git/malco/out_itanoeng/raw_results/multilingual/it/results.yaml")
ita_result = read_raw_result_yaml(ita_file)

# Extract the ranked differential (extracted_object.terms) from the ita yaml.
for ppkt_out in ita_result:
    extracted_object = ppkt_out.get("extracted_object")
    if not extracted_object:
        continue
    # Normalize the label so it matches the English-prompt keys used in `cres`.
    label = extracted_object.get("label").replace('_it-prompt', '_en-prompt')
    terms = extracted_object.get("terms")
    if not terms:
        continue
    for rank, term in enumerate(terms, start=1):
        data.append({
            "pubmedid": label,
            "term": term,
            "mondo_label": float('nan'),  # Italian replies carry no MONDO grounding
            "rank": rank,
            "ita_reply": True,
            "correct_omim_id": cres['term'][label],
            "correct_omim_description": cres['description'][label],
        })

# --- English replies -------------------------------------------------------
eng_file = Path("/Users/leonardo/git/malco/out_itanoeng/raw_results/multilingual/it_w_en/results.yaml")
eng_result = read_raw_result_yaml(eng_file)

# Extract terms, plus the grounded MONDO label from named_entities, for eng.
for ppkt_out in eng_result:
    extracted_object = ppkt_out.get("extracted_object")
    if not extracted_object:
        continue
    label = extracted_object.get("label").replace('_it-prompt', '_en-prompt')
    terms = extracted_object.get("terms")
    if not terms:
        continue
    for rank, term in enumerate(terms, start=1):
        # BUG FIX: previously `mlab` was clobbered with NaN by every
        # NON-matching entity (so a found label survived only if the matching
        # entity happened to be last), and was left stale/undefined when
        # `term` was not a MONDO ID. Initialize per-term, set only on match.
        mlab = float('nan')
        if term.startswith("MONDO"):
            for entity in ppkt_out.get("named_entities"):
                if entity.get('id') == term:
                    mlab = entity.get('label')
                    break
        data.append({
            "pubmedid": label,
            "term": mlab,
            "mondo_label": term,
            "rank": rank,
            "ita_reply": False,
            "correct_omim_id": cres["term"][label],
            "correct_omim_description": cres['description'][label],
        })

# --- Assemble and sort the curation table ----------------------------------
# Presentation column names; order must match the dict insertion order above.
column_names = [
    "PMID",
    "GPT Diagnosis",
    "MONDO ID",
    "rank",
    "ita_reply",
    "correct_OMIMid",
    "correct_OMIMlabel",
]

df = pd.DataFrame(data)
df.columns = column_names
df.sort_values(by=['PMID', 'ita_reply', 'rank'], inplace=True)
#df.to_excel(os.getcwd() + "ita_replies2curate.xlsx") # does not work, wrong path (missing separator), not important
9 changes: 8 additions & 1 deletion src/malco/post_process/post_process_results_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def create_standardised_results(
)
# terms will now ONLY contain MONDO IDs OR 'N/A'.
# The latter should be dealt with downstream
terms = [i[1][0][0] for i in result] # MONDO_ID
new_terms = []
for i in result:
if i[1] == [("N/A", "No grounding found")]:
new_terms.append(i[0])
else:
new_terms.append(i[1][0][0])
terms = new_terms
#terms = [i[1][0][0] for i in result] # MONDO_ID
if terms:
# Note, the if allows for rerunning ppkts that failed due to connection issues
# We can have multiple identical ppkts/prompts in results.yaml
Expand Down
23 changes: 15 additions & 8 deletions src/malco/post_process/ranking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def compute_mrr_and_ranks(
"n10p",
"nf",
"num_cases",
"grounding_failed", # and no correct reply elsewhere in the differential
]
rank_df = pd.DataFrame(0, index=np.arange(len(results_files)), columns=header)

Expand Down Expand Up @@ -143,6 +144,7 @@ def compute_mrr_and_ranks(
)

df.dropna(subset=["correct_term"])

# Save full data frame
full_df_path = output_dir / results_files[i].split("/")[0]
full_df_filename = "full_df_results.tsv"
Expand All @@ -155,14 +157,17 @@ def compute_mrr_and_ranks(
# Calculate top<n> of each rank
rank_df.loc[i, comparing] = results_files[i].split("/")[0]

ppkts = df.groupby("label")[["rank", "is_correct"]]
ppkts = df.groupby("label")[["term", "rank", "is_correct"]]

# for each group
for ppkt in ppkts:
# is there a true? ppkt is tuple ("filename", dataframe) --> ppkt[1] is a dataframe
if not any(ppkt[1]["is_correct"]):
# no --> increase nf = "not found"
rank_df.loc[i, "nf"] += 1
if all(ppkt[1]["term"].str.startswith("MONDO")):
# no --> increase nf = "not found"
rank_df.loc[i, "nf"] += 1
else:
rank_df.loc[i, "grounding_failed"] += 1
else:
# yes --> what's it rank? It's <j>
jind = ppkt[1].index[ppkt[1]["is_correct"]]
Expand Down Expand Up @@ -204,10 +209,12 @@ def compute_mrr_and_ranks(
writer.writerow(results_files)
writer.writerow(mrr_scores)

    # TODO this could be moved into an analysis script with the plotting...
df = pd.read_csv(topn_file, delimiter="\t")
df["top1"] = (df["n1"]) / df["num_cases"]
df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / df["num_cases"]
df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / df["num_cases"]
valid_cases = df["num_cases"] - df["grounding_failed"]
df["top1"] = (df["n1"]) / valid_cases
df["top3"] = (df["n1"] + df["n2"] + df["n3"]) / valid_cases
df["top5"] = (df["n1"] + df["n2"] + df["n3"] + df["n4"] + df["n5"]) / valid_cases
df["top10"] = (
df["n1"]
+ df["n2"]
Expand All @@ -219,8 +226,8 @@ def compute_mrr_and_ranks(
+ df["n8"]
+ df["n9"]
+ df["n10"]
) / df["num_cases"]
df["not_found"] = (df["nf"]) / df["num_cases"]
) / valid_cases
df["not_found"] = (df["nf"]) / valid_cases

df_aggr = pd.DataFrame()
df_aggr = pd.melt(
Expand Down

0 comments on commit fe51a68

Please sign in to comment.