diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl
index 7935b08d..5cddbcc5 100644
--- a/src/Neuroblox.jl
+++ b/src/Neuroblox.jl
@@ -188,7 +188,7 @@ export IFNeuronBlox, LIFNeuronBlox, QIFNeuronBlox, HHNeuronExciBlox, HHNeuronInh
 export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox
 export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
 export HebbianPlasticity, HebbianModulationPlasticity
-export Agent, ClassificationEnvironment, GreedyPolicy
+export Agent, ClassificationEnvironment, GreedyPolicy, reset!
 export LearningBlox
 export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus
 export PowerSpectrumBlox, BandPassFilterBlox
diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl
index a6bf7a84..74d2370d 100644
--- a/src/blox/reinforcement_learning.jl
+++ b/src/blox/reinforcement_learning.jl
@@ -86,14 +86,14 @@ function maybe_set_state_post!(lr::AbstractLearningRule, state)
     end
 end
 
-mutable struct ClassificationEnvironment <: AbstractEnvironment
-    const name
-    const namespace
-    const source
-    const category
-    const N_trials
-    const t_trial
-    current_trial
+mutable struct ClassificationEnvironment{S} <: AbstractEnvironment
+    const name::Symbol
+    const namespace::Symbol
+    const source::S
+    const category::Vector{Int}
+    const N_trials::Int
+    const t_trial::Float64
+    current_trial::Int
 
     function ClassificationEnvironment(data::DataFrame; name, namespace=nothing, t_stimulus, t_pause)
         stim = ImageStimulus(
@@ -108,14 +108,14 @@ mutable struct ClassificationEnvironment <: AbstractEnvironment
         N_trials = stim.N_stimuli
         t_trial = t_stimulus + t_pause
 
-        new(name, namespace, stim, category, N_trials, t_trial, 1)
+        new{typeof(stim)}(Symbol(name), Symbol(namespace), stim, category, N_trials, t_trial, 1)
     end
 
     function ClassificationEnvironment(stim::ImageStimulus; name, namespace=nothing)
         t_trial = stim.t_stimulus + stim.t_pause
         N_trials = stim.N_stimuli
 
-        new(name, namespace, stim, stim.category, N_trials, t_trial, 1)
+        new{typeof(stim)}(Symbol(name), Symbol(namespace), stim, stim.category, N_trials, t_trial, 1)
     end
 end
 
@@ -123,6 +123,8 @@ end
 
 increment_trial!(env::AbstractEnvironment) = env.current_trial += 1
 
+reset!(env::AbstractEnvironment) = env.current_trial = 1
+
 function get_trial_stimulus(env::ClassificationEnvironment)
     stim_params = env.source.stim_parameters
     stim_values = env.source.IMG[:, env.current_trial]
@@ -133,11 +135,11 @@ end
 abstract type AbstractActionSelection <: AbstractBlox end
 
 mutable struct GreedyPolicy <: AbstractActionSelection
-    const name
-    const namespace
-    competitor_states
-    competitor_params
-    const t_decision
+    const name::Symbol
+    const namespace::Symbol
+    competitor_states::Vector{Num}
+    competitor_params::Vector{Num}
+    const t_decision::Float64
 
     function GreedyPolicy(; name, t_decision, namespace=nothing, competitor_states=nothing, competitor_params=nothing)
         sts = isnothing(competitor_states) ? Num[] : competitor_states
@@ -150,6 +152,7 @@ function (p::GreedyPolicy)(sol::SciMLBase.AbstractSciMLSolution)
     comp_vals = sol(p.t_decision; idxs=p.competitor_states)
     return argmax(comp_vals)
 end
+
 """
 function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem)
     ps = parameters(sys)
@@ -166,11 +169,13 @@ function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem)
     return argmax(comp_vals)
 end
 """
-mutable struct Agent
-    odesystem
-    problem
-    action_selection
-    learning_rules
+
+mutable struct Agent{S,P,A,LR}
+    odesystem::S
+    problem::P
+    action_selection::A
+    learning_rules::LR
+    init_params::Vector{Float64}
 
     function Agent(g::MetaDiGraph; name, kwargs...)
         bc = connector_from_graph(g)
@@ -183,16 +188,19 @@ mutable struct Agent
         u0 = haskey(kwargs, :u0) ? kwargs[:u0] : []
         p = haskey(kwargs, :p) ? kwargs[:p] : []
 
-        prob = ODEProblem(ss, u0, (0,1), p)
-
+        prob = ODEProblem(ss, u0, (0.,1.), p)
+        init_params = prob.p
+
         policy = action_selection_from_graph(g)
         learning_rules = bc.learning_rules
 
-        new(ss, prob, policy, learning_rules)
+
+
+        new{typeof(sys), typeof(prob), typeof(policy), typeof(learning_rules)}(ss, prob, policy, learning_rules, init_params)
     end
 end
 
-
+reset!(ag::Agent) = ag.problem = remake(ag.problem; p = ag.init_params)
 
 function run_experiment!(agent::Agent, env::ClassificationEnvironment, t_warmup=200.0; kwargs...)
     N_trials = env.N_trials
diff --git a/test/reinforcement_learning.jl b/test/reinforcement_learning.jl
index 25e0e915..39fee2eb 100644
--- a/test/reinforcement_learning.jl
+++ b/test/reinforcement_learning.jl
@@ -69,4 +69,10 @@ using CSV
     @test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]])
     # All non-weight parameters need to be the same.
     @test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]])
+
+    reset!(agent)
+    @test all(init_params .== agent.problem.p)
+
+    reset!(env)
+    @test env.current_trial == 1
 end
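
Usage sketch (not part of the patch): how the new `reset!` methods are intended to be called around `run_experiment!`. The CSV path, stimulus timings, and the circuit graph `g` below are placeholders, and the solver keyword arguments are assumed to be forwarded by `run_experiment!` to `solve`.

```julia
using Neuroblox, CSV, DataFrames
using OrdinaryDiffEq

# Placeholder stimulus data and a circuit graph `g` assumed to be built elsewhere.
data  = CSV.read("stimuli.csv", DataFrame)
env   = ClassificationEnvironment(data; name=:env, t_stimulus=600.0, t_pause=1000.0)
agent = Agent(g; name=:agent)

run_experiment!(agent, env; alg=Vern7(), reltol=1e-9, abstol=1e-9)

# After a run, the environment points past the last trial and the agent's
# problem carries the learned weights; the new methods restore both:
reset!(env)    # rewinds env.current_trial to 1
reset!(agent)  # remakes agent.problem with the parameters captured at construction
```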