Merge pull request #340 from Neuroblox/ho/reset
Add `reset!` function for `Agent` and `AbstractEnvironment`
harisorgn authored Feb 15, 2024
2 parents e7bf50d + efaf386 commit fe808d8
Showing 3 changed files with 39 additions and 25 deletions.
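
Taken together, the change lets an experiment be rewound and rerun from a clean state: `reset!(env)` puts a `ClassificationEnvironment` back on trial 1, and `reset!(agent)` restores the agent's `ODEProblem` to the parameter values it was built with. A minimal usage sketch (the graph `g`, the stimulus `DataFrame` `df`, and the timing values are illustrative placeholders, not part of this commit):

using Neuroblox

# `g` is a MetaDiGraph describing the circuit; `df` holds the image stimuli.
agent = Agent(g; name = :agent)
env   = ClassificationEnvironment(df; name = :env, t_stimulus = 600.0, t_pause = 1000.0)

run_experiment!(agent, env)   # mutates agent.problem parameters and advances env.current_trial

reset!(agent)                 # restore the initial ODEProblem parameters
reset!(env)                   # rewind the environment to trial 1

run_experiment!(agent, env)   # a second, independent run from the same starting point
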
2 changes: 1 addition & 1 deletion src/Neuroblox.jl
@@ -188,7 +188,7 @@ export IFNeuronBlox, LIFNeuronBlox, QIFNeuronBlox, HHNeuronExciBlox, HHNeuronInh
export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox
export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy
export Agent, ClassificationEnvironment, GreedyPolicy, reset!
export LearningBlox
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus
export PowerSpectrumBlox, BandPassFilterBlox
56 changes: 32 additions & 24 deletions src/blox/reinforcement_learning.jl
@@ -86,14 +86,14 @@ function maybe_set_state_post!(lr::AbstractLearningRule, state)
end
end

mutable struct ClassificationEnvironment <: AbstractEnvironment
const name
const namespace
const source
const category
const N_trials
const t_trial
current_trial
mutable struct ClassificationEnvironment{S} <: AbstractEnvironment
const name::Symbol
const namespace::Symbol
const source::S
const category::Vector{Int}
const N_trials::Int
const t_trial::Float64
current_trial::Int

function ClassificationEnvironment(data::DataFrame; name, namespace=nothing, t_stimulus, t_pause)
stim = ImageStimulus(
@@ -108,21 +108,23 @@ mutable struct ClassificationEnvironment <: AbstractEnvironment
N_trials = stim.N_stimuli
t_trial = t_stimulus + t_pause

new(name, namespace, stim, category, N_trials, t_trial, 1)
new{typeof(stim)}(Symbol(name), Symbol(namespace), stim, category, N_trials, t_trial, 1)
end

function ClassificationEnvironment(stim::ImageStimulus; name, namespace=nothing)
t_trial = stim.t_stimulus + stim.t_pause
N_trials = stim.N_stimuli

new(name, namespace, stim, stim.category, N_trials, t_trial, 1)
new{typeof(stim)}(Symbol(name), Symbol(namespace), stim, stim.category, N_trials, t_trial, 1)
end
end

(env::ClassificationEnvironment)(action) = action == env.category[env.current_trial]

increment_trial!(env::AbstractEnvironment) = env.current_trial += 1

reset!(env::AbstractEnvironment) = env.current_trial = 1

function get_trial_stimulus(env::ClassificationEnvironment)
stim_params = env.source.stim_parameters
stim_values = env.source.IMG[:, env.current_trial]
@@ -133,11 +135,11 @@ end
abstract type AbstractActionSelection <: AbstractBlox end

mutable struct GreedyPolicy <: AbstractActionSelection
const name
const namespace
competitor_states
competitor_params
const t_decision
const name::Symbol
const namespace::Symbol
competitor_states::Vector{Num}
competitor_params::Vector{Num}
const t_decision::Float64

function GreedyPolicy(; name, t_decision, namespace=nothing, competitor_states=nothing, competitor_params=nothing)
sts = isnothing(competitor_states) ? Num[] : competitor_states
@@ -150,6 +152,7 @@ function (p::GreedyPolicy)(sol::SciMLBase.AbstractSciMLSolution)
comp_vals = sol(p.t_decision; idxs=p.competitor_states)
return argmax(comp_vals)
end

"""
function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem)
ps = parameters(sys)
@@ -166,11 +169,13 @@ function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem)
return argmax(comp_vals)
end
"""
mutable struct Agent
odesystem
problem
action_selection
learning_rules

mutable struct Agent{S,P,A,LR}
odesystem::S
problem::P
action_selection::A
learning_rules::LR
init_params::Vector{Float64}

function Agent(g::MetaDiGraph; name, kwargs...)
bc = connector_from_graph(g)
@@ -183,16 +188,19 @@ mutable struct Agent
u0 = haskey(kwargs, :u0) ? kwargs[:u0] : []
p = haskey(kwargs, :p) ? kwargs[:p] : []

prob = ODEProblem(ss, u0, (0,1), p)

prob = ODEProblem(ss, u0, (0.,1.), p)
init_params = prob.p

policy = action_selection_from_graph(g)
learning_rules = bc.learning_rules

new(ss, prob, policy, learning_rules)


new{typeof(sys), typeof(prob), typeof(policy), typeof(learning_rules)}(ss, prob, policy, learning_rules, init_params)
end
end


reset!(ag::Agent) = ag.problem = remake(ag.problem; p = ag.init_params)

function run_experiment!(agent::Agent, env::ClassificationEnvironment, t_warmup=200.0; kwargs...)
N_trials = env.N_trials
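The agent-side reset leans on `SciMLBase.remake`, which rebuilds a problem with selected fields replaced while leaving the rest untouched; the constructor now snapshots `prob.p` into `init_params` so `reset!` can hand those values back later. A standalone sketch of that pattern on a toy ODE (nothing below is Neuroblox code):

using SciMLBase

# Two-parameter toy system standing in for the agent's simplified ODE system.
f(u, p, t) = [p[1] * u[1], -p[2] * u[2]]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0), [1.5, 1.0])

init_params = copy(prob.p)            # snapshot at construction, as Agent now does

prob = remake(prob; p = [2.0, 0.5])   # learning rules update parameters between trials
prob = remake(prob; p = init_params)  # reset!(agent) restores the stored snapshot

prob.p == init_params                 # true
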
6 changes: 6 additions & 0 deletions test/reinforcement_learning.jl
@@ -69,4 +69,10 @@ using CSV
@test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]])
# All non-weight parameters need to be the same.
@test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]])

reset!(agent)
@test all(init_params .== agent.problem.p)

reset!(env)
@test env.current_trial == 1
end
