From a26b74967b4dc97a6b571c7c8f66aad24c151ce6 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 15 Feb 2024 15:44:17 +0200 Subject: [PATCH 1/5] add type assertions for RL structs --- src/blox/reinforcement_learning.jl | 47 +++++++++++++++++------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index a6bf7a84..459ad841 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -86,14 +86,14 @@ function maybe_set_state_post!(lr::AbstractLearningRule, state) end end -mutable struct ClassificationEnvironment <: AbstractEnvironment - const name - const namespace - const source - const category - const N_trials - const t_trial - current_trial +mutable struct ClassificationEnvironment{S} <: AbstractEnvironment + const name::Symbol + const namespace::Symbol + const source::S + const category::Vector{Int} + const N_trials::Int + const t_trial::Float64 + current_trial::Int function ClassificationEnvironment(data::DataFrame; name, namespace=nothing, t_stimulus, t_pause) stim = ImageStimulus( @@ -108,14 +108,14 @@ mutable struct ClassificationEnvironment <: AbstractEnvironment N_trials = stim.N_stimuli t_trial = t_stimulus + t_pause - new(name, namespace, stim, category, N_trials, t_trial, 1) + new{typeof(stim)}(Symbol(name), Symbol(namespace), stim, category, N_trials, t_trial, 1) end function ClassificationEnvironment(stim::ImageStimulus; name, namespace=nothing) t_trial = stim.t_stimulus + stim.t_pause N_trials = stim.N_stimuli - new(name, namespace, stim, stim.category, N_trials, t_trial, 1) + new{typeof(stim)}(Symbol(name), Symbol(namespace), stim, stim.category, N_trials, t_trial, 1) end end @@ -133,11 +133,11 @@ end abstract type AbstractActionSelection <: AbstractBlox end mutable struct GreedyPolicy <: AbstractActionSelection - const name - const namespace - competitor_states - competitor_params - const t_decision + const name::Symbol + const namespace::Symbol + competitor_states::Vector{Num} + competitor_params::Vector{Num} + const t_decision::Float64 function GreedyPolicy(; name, t_decision, namespace=nothing, competitor_states=nothing, competitor_params=nothing) sts = isnothing(competitor_states) ? Num[] : competitor_states @@ -150,6 +150,7 @@ function (p::GreedyPolicy)(sol::SciMLBase.AbstractSciMLSolution) comp_vals = sol(p.t_decision; idxs=p.competitor_states) return argmax(comp_vals) end + """ function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem) ps = parameters(sys) @@ -166,11 +167,13 @@ function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem) return argmax(comp_vals) end """ -mutable struct Agent - odesystem - problem - action_selection - learning_rules + +mutable struct Agent{S,P,A,LR} + odesystem::S + problem::P + action_selection::A + learning_rules::LR + init_params::Vector{Float64} function Agent(g::MetaDiGraph; name, kwargs...) bc = connector_from_graph(g) @@ -188,7 +191,9 @@ mutable struct Agent policy = action_selection_from_graph(g) learning_rules = bc.learning_rules - new(ss, prob, policy, learning_rules) + + + new{typeof(sys), typeof(prob), typeof(policy), typeof(learning_rules)}(ss, prob, policy, learning_rules, init_params) end end From 69b46a523c9b2bb35e4e5f0ad5b064f30ec24e37 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 15 Feb 2024 15:44:35 +0200 Subject: [PATCH 2/5] add initial parameters field in `Agent` --- src/blox/reinforcement_learning.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index 459ad841..59d13537 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -186,8 +186,9 @@ mutable struct Agent{S,P,A,LR} u0 = haskey(kwargs, :u0) ? kwargs[:u0] : [] p = haskey(kwargs, :p) ? kwargs[:p] : [] - prob = ODEProblem(ss, u0, (0,1), p) - + prob = ODEProblem(ss, u0, (0.,1.), p) + init_params = prob.p + policy = action_selection_from_graph(g) learning_rules = bc.learning_rules From 7d8409ffbbd339fa05088e779d2eafb126417ddb Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 15 Feb 2024 15:44:59 +0200 Subject: [PATCH 3/5] add `reset!` function for `Agent` and `AbstractEnvironment` --- src/blox/reinforcement_learning.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index 59d13537..74d2370d 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -123,6 +123,8 @@ end increment_trial!(env::AbstractEnvironment) = env.current_trial += 1 +reset!(env::AbstractEnvironment) = env.current_trial = 1 + function get_trial_stimulus(env::ClassificationEnvironment) stim_params = env.source.stim_parameters stim_values = env.source.IMG[:, env.current_trial] @@ -198,7 +200,7 @@ mutable struct Agent{S,P,A,LR} end end - +reset!(ag::Agent) = ag.problem = remake(ag.problem; p = ag.init_params) function run_experiment!(agent::Agent, env::ClassificationEnvironment, t_warmup=200.0; kwargs...) N_trials = env.N_trials From 183c428d2d2e757d5c05550b9e291953ae38068f Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 15 Feb 2024 15:45:10 +0200 Subject: [PATCH 4/5] export `reset!` --- src/Neuroblox.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 7935b08d..5cddbcc5 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -188,7 +188,7 @@ export IFNeuronBlox, LIFNeuronBlox, QIFNeuronBlox, HHNeuronExciBlox, HHNeuronInh export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc export HebbianPlasticity, HebbianModulationPlasticity -export Agent, ClassificationEnvironment, GreedyPolicy +export Agent, ClassificationEnvironment, GreedyPolicy, reset! export LearningBlox export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus export PowerSpectrumBlox, BandPassFilterBlox From efaf3868885678808d17920ee81990119d86f8f6 Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 15 Feb 2024 15:45:33 +0200 Subject: [PATCH 5/5] use `reset!` in RL test --- test/reinforcement_learning.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/reinforcement_learning.jl b/test/reinforcement_learning.jl index 25e0e915..39fee2eb 100644 --- a/test/reinforcement_learning.jl +++ b/test/reinforcement_learning.jl @@ -69,4 +69,10 @@ using CSV @test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]]) # All non-weight parameters need to be the same. @test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]]) + + reset!(agent) + @test all(init_params .== agent.problem.p) + + reset!(env) + @test env.current_trial == 1 end