Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
mooreryan committed Jan 6, 2025
1 parent 10410ea commit 2253316
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@
"outputs": [],
"source": [
"def get_metrics(true_labels, predicted_labels):\n",
"\n",
" print(\n",
" \"Accuracy:\", np.round(metrics.accuracy_score(true_labels, predicted_labels), 4)\n",
" )\n",
Expand All @@ -194,15 +193,13 @@
"\n",
"\n",
"def display_classification_report(true_labels, predicted_labels, classes=[1, 0]):\n",
"\n",
" report = metrics.classification_report(\n",
" y_true=true_labels, y_pred=predicted_labels, labels=classes\n",
" )\n",
" print(report)\n",
"\n",
"\n",
"def display_confusion_matrix(true_labels, predicted_labels, classes=[1, 0]):\n",
"\n",
" total_classes = len(classes)\n",
" level_labels = [total_classes * [0], list(range(total_classes))]\n",
" cm = metrics.confusion_matrix(\n",
Expand Down Expand Up @@ -235,7 +232,6 @@
"def plot_model_roc_curve(\n",
" clf, features, true_labels, label_encoder=None, class_names=None\n",
"):\n",
"\n",
" ## Compute ROC curve and ROC area for each class\n",
" fpr = dict()\n",
" tpr = dict()\n",
Expand Down Expand Up @@ -309,25 +305,28 @@
" plt.plot(\n",
" fpr[\"micro\"],\n",
" tpr[\"micro\"],\n",
" label=\"micro-average ROC curve (area = {0:0.2f})\"\n",
" \"\".format(roc_auc[\"micro\"]),\n",
" label=\"micro-average ROC curve (area = {0:0.2f})\" \"\".format(\n",
" roc_auc[\"micro\"]\n",
" ),\n",
" linewidth=3,\n",
" )\n",
"\n",
" plt.plot(\n",
" fpr[\"macro\"],\n",
" tpr[\"macro\"],\n",
" label=\"macro-average ROC curve (area = {0:0.2f})\"\n",
" \"\".format(roc_auc[\"macro\"]),\n",
" label=\"macro-average ROC curve (area = {0:0.2f})\" \"\".format(\n",
" roc_auc[\"macro\"]\n",
" ),\n",
" linewidth=3,\n",
" )\n",
"\n",
" for i, label in enumerate(class_labels):\n",
" plt.plot(\n",
" fpr[i],\n",
" tpr[i],\n",
" label=\"ROC curve of class {0} (area = {1:0.2f})\"\n",
" \"\".format(label, roc_auc[i]),\n",
" label=\"ROC curve of class {0} (area = {1:0.2f})\" \"\".format(\n",
" label, roc_auc[i]\n",
" ),\n",
" linewidth=2,\n",
" linestyle=\":\",\n",
" )\n",
Expand All @@ -354,7 +353,6 @@
" alphas=None,\n",
" colors=None,\n",
"):\n",
"\n",
" if train_features.shape[1] != 2:\n",
" raise ValueError(\"X_train should have exactly 2 columnns!\")\n",
"\n",
Expand Down Expand Up @@ -1243,7 +1241,11 @@
]
}
],
"metadata": {},
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit 2253316

Please sign in to comment.