Commit 59db3d78 authored by hugopiq's avatar hugopiq
Browse files

confusion matrix ANN python

parent 87c60971
This diff is collapsed.
......@@ -9,7 +9,7 @@ from sklearn.preprocessing import StandardScaler
from sklearn import preprocessing
from matplotlib import pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn import metrics
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
def train_ANN(saveWeighs=False, saveAlgo=False):
......@@ -66,10 +66,14 @@ def train_ANN(saveWeighs=False, saveAlgo=False):
# plt.grid(True)
# plt.gca().set_ylim(0, 2) # set the vertical range to [0-1]
# plt.show()
predicted_labels = model.predict(X_test)
# Confusion matrix
predicted_labels = np.argmax(model.predict(X_test), axis=1)
print(confusion_matrix(labelInd_test, predicted_labels))
ConfusionMatrixDisplay.from_predictions(
labelInd_test, predicted_labels, display_labels=classes)
plt.title("My predictions")
plt.show()
# Save weights to a .h file
list_weights = []
for layer in model.layers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment