Skip to content

Commit

Permalink
modified: CHANGELOG.md
Browse files Browse the repository at this point in the history
	modified:   Project.toml
	modified:   src/sub/design_models.jl
	modified:   test/system_evaluation_test.jl
  • Loading branch information
Pierre BLAUD committed Mar 29, 2023
1 parent a40615c commit 17c825e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## v0.1.7

* Add dict for return for _extract_model_from_machine functions.
* Tests improvement.

## v0.1.6
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AutomationLabsSystems"
uuid = "6d3dfdf0-e107-48c6-a6ff-eced1a5a9334"
authors = ["Pierre BLAUD <pierre.blaud@ikmail.com> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
AutomationLabsIdentification = "48ff5a6f-d08b-4053-9585-6a9e3e078386"
Expand Down
12 changes: 9 additions & 3 deletions src/sub/design_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ function _extract_model_from_machine(
)

# Extract best model from the machine
return MLJ.fitted_params(MLJ.fitted_params(machine_mlj).machine).best_fitted_params[1]
f = MLJ.fitted_params(MLJ.fitted_params(machine_mlj).machine).best_fitted_params[1]

result = Dict(:f => f)

return result
end

"""
Expand All @@ -53,15 +57,17 @@ A function for design the system (model and constraints) with MathematicalSystem
function _extract_model_from_machine(
model_type::MLJMultivariateStatsInterface.MultitargetLinearRegressor,
machine_mlj,
nbr_state,
)

# Extract model from the machine
AB_t = MLJ.fitted_params(machine_mlj).coefficients
nbr_state = size(AB_t, 2)

AB = copy(AB_t')
A = AB[:, 1:nbr_state]
B = AB[:, nbr_state+1:end]

result = Dict(:A => A, :B => B)
# Set the system
return A, B
return result
end
30 changes: 15 additions & 15 deletions test/system_evaluation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,26 @@ import AutomationLabsSystems: _extract_model_from_machine
f_rnn = _extract_model_from_machine(rnn_type, rnn)
f_lstm = _extract_model_from_machine(lstm_type, lstm)
f_gru = _extract_model_from_machine(gru_type, gru)
f_linear = _extract_model_from_machine(linear_type, linear, 4)
f_linear = _extract_model_from_machine(linear_type, linear)

nbr_state = 4
nbr_input = 2
variation = "discrete"

sys_fnn = proceed_system(f_fnn, nbr_state, nbr_input, variation)
sys_icnn = proceed_system(f_icnn, nbr_state, nbr_input, variation)
sys_resnet = proceed_system(f_resnet, nbr_state, nbr_input, variation)
sys_polynet = proceed_system(f_polynet, nbr_state, nbr_input, variation)
sys_densenet = proceed_system(f_densenet, nbr_state, nbr_input, variation)
sys_rbf = proceed_system(f_rbf, nbr_state, nbr_input, variation)
sys_neuralode = proceed_system(f_neuralode, nbr_state, nbr_input, variation)
sys_rknn1 = proceed_system(f_rknn1, nbr_state, nbr_input, variation)
sys_rknn2 = proceed_system(f_rknn2, nbr_state, nbr_input, variation)
sys_rknn4 = proceed_system(f_rknn4, nbr_state, nbr_input, variation)
sys_rnn = proceed_system(f_rnn, nbr_state, nbr_input, variation)
sys_lstm = proceed_system(f_lstm, nbr_state, nbr_input, variation)
sys_gru = proceed_system(f_gru, nbr_state, nbr_input, variation)
sys_linear = proceed_system(f_linear[1], f_linear[2], nbr_state, nbr_input, variation)
sys_fnn = proceed_system(f_fnn[:f], nbr_state, nbr_input, variation)
sys_icnn = proceed_system(f_icnn[:f], nbr_state, nbr_input, variation)
sys_resnet = proceed_system(f_resnet[:f], nbr_state, nbr_input, variation)
sys_polynet = proceed_system(f_polynet[:f], nbr_state, nbr_input, variation)
sys_densenet = proceed_system(f_densenet[:f], nbr_state, nbr_input, variation)
sys_rbf = proceed_system(f_rbf[:f], nbr_state, nbr_input, variation)
sys_neuralode = proceed_system(f_neuralode[:f], nbr_state, nbr_input, variation)
sys_rknn1 = proceed_system(f_rknn1[:f], nbr_state, nbr_input, variation)
sys_rknn2 = proceed_system(f_rknn2[:f], nbr_state, nbr_input, variation)
sys_rknn4 = proceed_system(f_rknn4[:f], nbr_state, nbr_input, variation)
sys_rnn = proceed_system(f_rnn[:f], nbr_state, nbr_input, variation)
sys_lstm = proceed_system(f_lstm[:f], nbr_state, nbr_input, variation)
sys_gru = proceed_system(f_gru[:f], nbr_state, nbr_input, variation)
sys_linear = proceed_system(f_linear[:A], f_linear[:B], nbr_state, nbr_input, variation)

fnn_type_2 = proceed_system_model_evaluation(sys_fnn)
icnn_type_2 = proceed_system_model_evaluation(sys_icnn)
Expand Down

0 comments on commit 17c825e

Please sign in to comment.