Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Estelle ARRICAU
Embedded Machine Learning
Commits
328235bf
Commit
328235bf
authored
Jan 09, 2022
by
hugopiq
Browse files
ok
parent
213f93f6
Changes
1
Hide whitespace changes
Inline
Side-by-side
ANN/Python/ANN_training.py
View file @
328235bf
...
...
@@ -11,11 +11,9 @@ from matplotlib import pyplot as plt
def
train_ANN
(
saveWeighs
=
False
,
saveAlgo
=
False
):
# dataset = r'features.csv'
dataset
=
"build/Extraction/features.csv"
df
=
pd
.
read_csv
(
dataset
,
header
=
0
)
features
=
df
.
columns
.
values
[:
-
2
]
# print(features)
Y
=
df
.
Style
.
values
X
=
df
.
values
classes
=
np
.
unique
(
Y
)
...
...
@@ -24,7 +22,7 @@ def train_ANN(saveWeighs=False, saveAlgo=False):
X
,
Y
,
test_size
=
0.33
,
random_state
=
42
)
X_test
,
X_val
,
Y_test
,
Y_val
=
train_test_split
(
X_test
,
Y_test
,
test_size
=
0.33
,
random_state
=
42
)
# Save test'set to a csv in order to compute accuracy in cpp
X_train
=
np
.
delete
(
X_train
,
[
-
1
,
-
2
],
axis
=
1
)
X_val
=
np
.
delete
(
X_val
,
[
-
1
,
-
2
],
axis
=
1
)
new_df
=
pd
.
DataFrame
(
X_test
)
...
...
@@ -67,11 +65,11 @@ def train_ANN(saveWeighs=False, saveAlgo=False):
plt
.
gca
().
set_ylim
(
0
,
2
)
# set the vertical range to [0-1]
plt
.
show
()
# Save weights to a .h file
list_weights
=
[]
for
layer
in
model
.
layers
:
weights
=
layer
.
get_weights
()
list_weights
.
append
(
weights
)
# print(str(layer.name) + ":" + str(weights))
mean
=
scaler
.
mean_
std
=
np
.
sqrt
(
scaler
.
var_
)
mean
=
transformListToStr
(
mean
)
...
...
@@ -92,27 +90,6 @@ def train_ANN(saveWeighs=False, saveAlgo=False):
return
model
,
history
def
save_model
(
model
,
history
):
# Save model
RESULT_PATH
=
code_folder
+
'/results'
# description
model_yaml
=
model
.
to_json
()
with
open
(
RESULT_PATH
+
"/modelANN.json"
,
"w"
)
as
yaml_file
:
yaml_file
.
write
(
model_yaml
)
# save model
model
.
save
(
RESULT_PATH
+
"/modelANN.h5"
)
# save weights of the model
model
.
save_weights
(
RESULT_PATH
+
"/modelANN.h5"
)
print
(
"Model saved to disk"
)
# Fit history saving
with
open
(
RESULT_PATH
+
'/trainHistoryANN'
,
'wb'
)
as
file_pi
:
pickle
.
dump
(
history
.
history
,
file_pi
)
def
transformArrayToStr
(
array
):
array
=
array
.
T
text
=
"{"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment