Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Estelle ARRICAU
Embedded Machine Learning
Commits
59db3d78
Commit
59db3d78
authored
Jan 11, 2022
by
hugopiq
Browse files
confusion matrix ANN python
parent
87c60971
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
ANN/ANNWeight.h
View file @
59db3d78
This diff is collapsed.
Click to expand it.
ANN/Python/ANN_training.py
View file @
59db3d78
...
...
@@ -9,7 +9,7 @@ from sklearn.preprocessing import StandardScaler
from
sklearn
import
preprocessing
from
matplotlib
import
pyplot
as
plt
from
sklearn.metrics
import
ConfusionMatrixDisplay
from
sklearn
import
metrics
from
sklearn
.metrics
import
ConfusionMatrixDisplay
,
confusion_matrix
def
train_ANN
(
saveWeighs
=
False
,
saveAlgo
=
False
):
...
...
@@ -66,10 +66,14 @@ def train_ANN(saveWeighs=False, saveAlgo=False):
# plt.grid(True)
# plt.gca().set_ylim(0, 2) # set the vertical range to [0-1]
# plt.show()
predicted_labels
=
model
.
predict
(
X_test
)
# Confusion matrix
predicted_labels
=
np
.
argmax
(
model
.
predict
(
X_test
),
axis
=
1
)
print
(
confusion_matrix
(
labelInd_test
,
predicted_labels
))
ConfusionMatrixDisplay
.
from_predictions
(
labelInd_test
,
predicted_labels
,
display_labels
=
classes
)
plt
.
title
(
"My predictions"
)
plt
.
show
()
# Save weights to a .h file
list_weights
=
[]
for
layer
in
model
.
layers
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment