Question: making LBA compatible with new interface #95

Open
itsdfish opened this issue Sep 15, 2019 · 16 comments

@itsdfish

Hi Tamas-

I was wondering if you might help me update this code for the new interface. While updating my other models was simple, this one seems to be a little trickier because I had to make some changes to the adaptation parameters to make it work. Here is how it used to work.

I tried several potential solutions, including initialization = (q = zeros(n), κ = GaussianKineticEnergy(5, 0.1)), but to no avail. Any guidance would be much appreciated. Thanks!

using Distributions, Parameters, DynamicHMC, LogDensityProblems, TransformVariables
using Random
import Distributions: pdf,logpdf,rand
export LBA,pdf,logpdf,rand

mutable struct LBA{T1,T2,T3,T4} <: ContinuousUnivariateDistribution
    ν::T1
    A::T2
    k::T3
    τ::T4
    σ::Float64
end

Base.broadcastable(x::LBA)=Ref(x)

LBA(;τ,A,k,ν,σ=1.0) = LBA(ν,A,k,τ,σ)

function selectWinner(dt)
    if any(x->x >0,dt)
        mi,mv = 0,Inf
        for (i,t) in enumerate(dt)
            if (t > 0) && (t < mv)
                mi = i
                mv = t
            end
        end
    else
        return 1,-1.0
    end
    return mi,mv
end

function sampleDriftRates(ν,σ)
    noPositive=true
    v = similar(ν)
    while noPositive
        v = [rand(Normal(d,σ)) for d in ν]
        any(x->x>0,v) ? noPositive=false : nothing
    end
    return v
end

function rand(d::LBA)
    @unpack τ,A,k,ν,σ = d
    b=A+k
    N = length(ν)
    v = sampleDriftRates(ν,σ)
    a = rand(Uniform(0,A),N)
    dt = @. (b-a)/v
    choice,mn = selectWinner(dt)
    rt = τ .+ mn
    return choice,rt
end

function rand(d::LBA,N::Int)
    choice = fill(0,N)
    rt = fill(0.0,N)
    for i in 1:N
        choice[i],rt[i]=rand(d)
    end
    return (choice=choice,rt=rt)
end

logpdf(d::LBA,choice,rt) = log(pdf(d,choice,rt))

function logpdf(d::LBA,data::T) where {T<:NamedTuple}
    return sum(logpdf.(d,data...))
end

function logpdf(dist::LBA,data::Array{<:Tuple,1})
    LL = 0.0
    for d in data
        LL += logpdf(dist,d...)
    end
    return LL
end

function pdf(d::LBA,c,rt)
    @unpack τ,A,k,ν,σ = d
    b=A+k; den = 1.0
    rt < τ ? (return 1e-10) : nothing
    for (i,v) in enumerate(ν)
        if c == i
            den *= dens(d,v,rt)
        else
            den *= (1-cummulative(d,v,rt))
        end
    end
    pneg = pnegative(d)
    den = den/(1-pneg)
    den = max(den,1e-10)
    isnan(den) ? (return 0.0) : (return den)
end

logpdf(d::LBA,data::Tuple) = logpdf(d,data...)

function dens(d::LBA,v,rt)
    @unpack τ,A,k,ν,σ = d
    dt = rt-τ; b=A+k
    n1 = (b-A-dt*v)/(dt*σ)
    n2 = (b-dt*v)/(dt*σ)
    dens = (1/A)*(-v*cdf(Normal(0,1),n1) + σ*pdf(Normal(0,1),n1) +
        v*cdf(Normal(0,1),n2) - σ*pdf(Normal(0,1),n2))
    return dens
end

function cummulative(d::LBA,v,rt)
    @unpack τ,A,k,ν,σ = d
    dt = rt-τ; b=A+k
    n1 = (b-A-dt*v)/(dt*σ)
    n2 = (b-dt*v)/(dt*σ)
    cm = 1 + ((b-A-dt*v)/A)*cdf(Normal(0,1),n1) -
        ((b-dt*v)/A)*cdf(Normal(0,1),n2) + ((dt*σ)/A)*pdf(Normal(0,1),n1) -
        ((dt*σ)/A)*pdf(Normal(0,1),n2)
    return cm
end

function pnegative(d::LBA)
    @unpack ν,σ=d
    p=1.0
    for v in ν
        p*= cdf(Normal(0,1),-v/σ)
    end
    return p
end

struct LBAProb{T}
    data::T
    N::Int
    Nc::Int
end

function (problem::LBAProb)(θ)
    @unpack data=problem
    @unpack v,A,k,tau=θ
    d=LBA(ν=v,A=A,k=k,τ=tau)
    minRT = minimum(x->x[2],data)
    logpdf(d,data) + sum(logpdf.(TruncatedNormal(0,3,0,Inf),v)) +
        logpdf(TruncatedNormal(.8,.4,0,Inf),A) + logpdf(TruncatedNormal(.2,.3,0,Inf),k) +
        logpdf(TruncatedNormal(.4,.1,0,minRT),tau)
end

function sampleDHMC(choice,rt,N,Nc,nsamples)
    data = [(c,r) for (c,r) in zip(choice,rt)]
    return sampleDHMC(data,N,Nc,nsamples)
end

# Define problem with data and inits.
function sampleDHMC(data,N,Nc,nsamples)
    p = LBAProb(data,N,Nc)
    p((v=fill(.5,Nc),A=.8,k=.2,tau=.4))
    # Build the transformation with the proper dimensions (v has Nc elements).
    trans = as((v=as(Array,asℝ₊,Nc),A=asℝ₊,k=asℝ₊,tau=asℝ₊))
    # Use ForwardDiff for the gradient.
    P = TransformedLogDensity(trans, p)
    ∇P = ADgradient(:ForwardDiff, P)
    # Sample from the posterior.
    n = dimension(trans)
    results = mcmc_with_warmup(Random.GLOBAL_RNG, ∇P, nsamples;
        q = zeros(n), p = ones(n),reporter = NoProgressReport())
    # Undo the transformation to obtain the posterior from the chain.
    posterior = transform.(trans, results.chain)
    # nptochain is not defined in this snippet (presumably a helper from MCMCBenchmarks).
    chns = nptochain(results,posterior)
    return chns
end

function simulateLBA(;Nd,v=[1.0,1.5,2.0],A=.8,k=.2,tau=.4,kwargs...)
    return (rand(LBA(ν=v,A=A,k=k,τ=tau),Nd)...,N=Nd,Nc=length(v))
end

data = simulateLBA(Nd=10)

samples = sampleDHMC(data...,2000)
@tpapp
Owner

tpapp commented Sep 15, 2019

If this is from the Statistical Rethinking book, can you please tell me where to find it?

@itsdfish
Author

This is actually a different model that Rob and I are using for MCMCBenchmarks.

@tpapp
Owner

tpapp commented Sep 16, 2019

Thanks. I think the issue is coding the log-likelihood in a numerically robust way. I will skim through the paper to understand the model and get back to you about this.

@itsdfish
Author

Thanks. Something like that might help.

Since the model works fairly well in Stan, I was wondering whether adopting Stan's NUTS configuration might help too. In fact, offering that as a preset configuration (e.g. setting something like Stan_Config) might be useful, assuming your defaults still differ from Stan's.
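One concrete part of "Stan's configuration" is its windowed warmup schedule. As a sketch, assuming Stan's documented defaults (num_warmup = 1000, init_buffer = 75, term_buffer = 50, window = 25), the following computes the sizes of the slow windows in which Stan adapts the metric. The function name `stan_windows` is hypothetical, and the stretching rule is a simplified reading of Stan's logic, not a port of its source.

```julia
"""
    stan_windows(num_warmup = 1000; init_buffer = 75, term_buffer = 50, window = 25)

Hypothetical helper: sizes of the slow (metric-adapting) warmup windows under a
Stan-style schedule. Windows double in size; when a further doubled window would
not fit after the current one, the current window is stretched to fill the rest.
"""
function stan_windows(num_warmup::Int = 1000; init_buffer::Int = 75,
                      term_buffer::Int = 50, window::Int = 25)
    slow = num_warmup - init_buffer - term_buffer  # steps available for metric adaptation
    sizes = Int[]
    used = 0
    w = window
    while used < slow
        # If the next doubled window would not fit after this one, absorb the rest now.
        if used + w + 2w > slow
            w = slow - used
        end
        push!(sizes, w)
        used += w
        w *= 2
    end
    return sizes
end

println(stan_windows())  # [25, 50, 100, 200, 500]
```

For comparison, the stack trace later in this thread shows DynamicHMC's default warmup as a local optimization, a step size search, an initial stage, five doubling `TuningNUTS{LinearAlgebra.Diagonal,...}` stages and a terminal stage, so the structure is similar; whether the stage lengths match Stan's is worth checking against the DynamicHMC documentation.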

@goedman

goedman commented Sep 17, 2019 via email

@itsdfish
Author

Thanks, Rob. That is further than I got.

Roughly every other run ends in an error. When it does run all the way through, the v parameters are off quite a bit, even when I increase the number of data points to 100.

ArgumentError: Value and slope at step length = 0 must be finite.
(::LineSearches.HagerZhang{Float64,Base.RefValue{Bool}})(::Function, ::getfield(LineSearches, Symbol("#ϕdϕ#6")){Optim.ManifoldObjective{NLSolversBase.OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}},Array{Float64,1},Array{Float64,1},Array{Float64,1}}, ::Float64, ::Float64, ::Float64) at hagerzhang.jl:117
HagerZhang at hagerzhang.jl:101 [inlined]
perform_linesearch!(::Optim.LBFGSState{Array{Float64,1},Array{Array{Float64,1},1},Array{Array{Float64,1},1},Float64,Array{Float64,1}}, ::Optim.LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},getfield(Optim, Symbol("##19#21"))}, ::Optim.ManifoldObjective{NLSolversBase.OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}}) at perform_linesearch.jl:53
update_state!(::NLSolversBase.OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Optim.LBFGSState{Array{Float64,1},Array{Array{Float64,1},1},Array{Array{Float64,1},1},Float64,Array{Float64,1}}, ::Optim.LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},getfield(Optim, Symbol("##19#21"))}) at l_bfgs.jl:198
optimize(::NLSolversBase.OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::Optim.LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},getfield(Optim, Symbol("##19#21"))}, ::Optim.Options{Float64,Nothing}, ::Optim.LBFGSState{Array{Float64,1},Array{Array{Float64,1},1},Array{Array{Float64,1},1},Float64,Array{Float64,1}}) at optimize.jl:57
optimize(::NLSolversBase.OnceDifferentiable{Float64,Array{Float64,1},Array{Float64,1}}, ::Array{Float64,1}, ::Optim.LBFGS{Nothing,LineSearches.InitialStatic{Float64},LineSearches.HagerZhang{Float64,Base.RefValue{Bool}},getfield(Optim, Symbol("##19#21"))}, ::Optim.Options{Float64,Nothing}) at optimize.jl:33
warmup at mcmc.jl:149 [inlined]
#25 at mcmc.jl:378 [inlined]
mapfoldl_impl(::typeof(identity), ::getfield(DynamicHMC, Symbol("##25#26")){DynamicHMC.SamplingLogDensity{MersenneTwister,LogDensityProblems.ForwardDiffLogDensity{TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:v, :A, :k, :tau),Tuple{TransformVariables.ArrayTransform{TransformVariables.ShiftedExp{true,Float64},1},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64}}}},LBAProb{Array{Tuple{Int64,Float64},1}}},ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##34#35")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:v, :A, :k, :tau),Tuple{TransformVariables.ArrayTransform{TransformVariables.ShiftedExp{true,Float64},1},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64}}}},LBAProb{Array{Tuple{Int64,Float64},1}}}},Float64},Float64,6,Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##34#35")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:v, :A, :k, :tau),Tuple{TransformVariables.ArrayTransform{TransformVariables.ShiftedExp{true,Float64},1},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64}}}},LBAProb{Array{Tuple{Int64,Float64},1}}}},Float64},Float64,6},1}}},DynamicHMC.NUTS{Val{:generalized}},LogProgressReport{Nothing}}}, ::NamedTuple{(:init,),Tuple{Tuple{Tuple{},DynamicHMC.WarmupState{DynamicHMC.EvaluatedLogDensity{Array{Float64,1},Float64},GaussianKineticEnergy{LinearAlgebra.Diagonal{Float64,Array{Float64,1}},LinearAlgebra.Diagonal{Float64,Array{Float64,1}}},Nothing}}}}, 
::Tuple{FindLocalOptimum{Float64},InitialStepsizeSearch,TuningNUTS{Nothing,DualAveraging{Float64}},TuningNUTS{LinearAlgebra.Diagonal,DualAveraging{Float64}},TuningNUTS{LinearAlgebra.Diagonal,DualAveraging{Float64}},TuningNUTS{LinearAlgebra.Diagonal,DualAveraging{Float64}},TuningNUTS{LinearAlgebra.Diagonal,DualAveraging{Float64}},TuningNUTS{LinearAlgebra.Diagonal,DualAveraging{Float64}},TuningNUTS{Nothing,DualAveraging{Float64}}}) at reduce.jl:45
#mapfoldl#187 at reduce.jl:72 [inlined]
#mapfoldl at none:0 [inlined]
#foldl#188 at reduce.jl:90 [inlined]
#foldl at none:0 [inlined]
_warmup(::DynamicHMC.SamplingLogDensity{MersenneTwister,LogDensityProblems.ForwardDiffLogDensity{TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:v, :A, :k, :tau),Tuple{TransformVariables.ArrayTransform{TransformVariables.ShiftedExp{true,Float64},1},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64},TransformVariables.ShiftedExp{true,Float64}}}},LBAProb{Array{Tuple{Int64,Float64},1}}},ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##34#35")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:v, :A, :k, :tau),Tuple{TransformVariables.ArrayT...

@goedman

goedman commented Sep 17, 2019

Yes, I think I managed to run Stan using the same data and get very different results.

@tpapp
Owner

tpapp commented Sep 17, 2019

I am working on this, but want to understand the model first.

@tpapp
Owner

tpapp commented Sep 17, 2019

I worked a bit on the code and put it in a repo

https://github.com/tpapp/LBA_problem

where you can track the changes I did.

There seems to be a numerical problem: you are multiplying densities, which underflow or overflow rather quickly. I made some changes but could not complete everything since I don't fully understand the model. The first thing I would recommend is that you finish this rewrite and verify that all the formulas are correct.
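To illustrate the over/underflow point, here is a minimal sketch. The first part shows why a product of many small densities breaks down in Float64. The second part is a hypothetical helper, `lba_loglik_term`, that computes the same per-observation quantity as `pdf` in the original post (winner density times the survival terms of the other accumulators, divided by 1 - pneg), but as a sum of logs, using `log1p` for the `log(1 - x)` terms. It assumes the component densities and CDFs are computed elsewhere (e.g. by `dens` and `cummulative`).

```julia
# Why multiplying densities fails: 100 observations, each with density 1e-5.
densities = fill(1e-5, 100)
naive = log(prod(densities))   # prod is 1e-500, which underflows to 0.0, so this is -Inf
robust = sum(log, densities)   # 100 * log(1e-5), finite

"""
    lba_loglik_term(logf_winner, F_losers, pneg)

Hypothetical helper: log of f_c(t) * prod_{j != c} (1 - F_j(t)) / (1 - pneg),
accumulated as a sum of logs. `logf_winner` is the log density of the winning
accumulator, `F_losers` holds the CDFs of the other accumulators at t, and
`pneg` is the probability that all drift rates are negative.
"""
lba_loglik_term(logf_winner, F_losers, pneg) =
    logf_winner + sum(F -> log1p(-F), F_losers; init = 0.0) - log1p(-pneg)

println(isinf(naive), " ", isfinite(robust))  # true true
```

Summing these terms over observations keeps the log-likelihood finite where the product of raw densities would hit zero.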

Then you should explore the robustness with LogDensityProblems.stresstest. If everything works, and there is still a bug, please get back to me.

@itsdfish
Author

Thanks. I appreciate your help. I'll get back to you as soon as I know something.

@tpapp
Owner

tpapp commented Oct 9, 2019

I am happy to keep this issue open, but please let me know if you need further help, or if the problem is solved now.

@itsdfish
Author

itsdfish commented Oct 9, 2019

Sorry about that. Rob and I were looking into the issue and he started making progress, but had to put it aside for a while. I'll close the issue for the time being and will reopen if we reach an impasse.

@itsdfish itsdfish closed this as completed Oct 9, 2019
@itsdfish
Author

Hi Tamas-

Rob and I hacked away at this problem but still have not come to a full resolution. I looked through your changes to the logpdf and found a minor error, but aside from that, it was good. Although it did not solve the numerical errors, it should reduce under/overflow, particularly in large data sets. Just as a reminder, this is the error message:

ArgumentError: Value and slope at step length = 0 must be finite.

After turning off the initial optimization stage, I encountered a domain error which produced the following message:

DomainError with [0.257204, 0.237231, 0.800883, -0.965525, -0.892265, -0.175778]:
Starting point has non-finite density.

Here is the up-to-date code to replicate this result. If I understand correctly, given the transformation bounds as((v=as(Array,asℝ₊,data.Nc),A=asℝ₊,k=asℝ₊,tau=asℝ₊)), the vector in the error message should contain all positive values. Is that correct?

Thanks for your help!

@itsdfish itsdfish reopened this Nov 11, 2019
@tpapp
Owner

tpapp commented Nov 11, 2019

Thanks for the heads up, I will look at this.

@itsdfish
Author

Thanks!

By the way, I just realized that the transformation on tau might need a finite upper bound:

minRT = minimum(data.rt)
trans = as((v=as(Array,asℝ₊,data.Nc),A=asℝ₊,k=asℝ₊,tau=as(Real,0,minRT)))

In either case, I want to check whether the negative values in the domain error are to be expected.
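A small sketch of why the bound on tau matters, under stated assumptions: as far as I know, TransformVariables implements a finite interval such as `as(Real, 0, minRT)` with a scaled logistic map, while `asℝ₊` uses exp. The value of `minRT` below is hypothetical (chosen only for illustration), and `x_tau` is the last coordinate of the point reported in the DomainError, which corresponds to tau given the order (v, A, k, tau) in the transformation. With exp, tau can land above minRT, where the `TruncatedNormal(.4,.1,0,minRT)` prior should give a log density of -Inf; the bounded map cannot leave (0, minRT). The helper names are hypothetical.

```julia
# Scaled logistic map from ℝ to the open interval (0, ub).
logistic(x) = 1 / (1 + exp(-x))
to_interval(x, ub) = ub * logistic(x)

# Log-Jacobian of to_interval: log(ub) + log(σ(x)) + log(1 - σ(x)),
# using log σ(x) = -log1p(exp(-x)) and log(1 - σ(x)) = -log1p(exp(x)).
logjac_interval(x, ub) = log(ub) - log1p(exp(-x)) - log1p(exp(x))

minRT = 0.5          # hypothetical minimum RT, for illustration only
x_tau = -0.175778    # tau coordinate of the point in the DomainError

println(exp(x_tau))                  # exp map (asℝ₊): about 0.839, above minRT
println(to_interval(x_tau, minRT))   # bounded map: about 0.228, inside (0, minRT)
```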

@tpapp
Copy link
Owner

tpapp commented Nov 12, 2019

I looked at your example.

First, those negative values are definitely to be expected. Sampling happens in the unconstrained space ℝⁿ; the transformation then maps those values to the constrained parameters.
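To make this concrete: the stack trace earlier in the thread shows `ShiftedExp{true,Float64}` for every parameter, i.e. each coordinate goes through an exp map. Assuming that, the point from the DomainError maps to strictly positive parameters, and the log-Jacobian of the elementwise exp is simply the sum of the unconstrained coordinates. This is a minimal illustration, not TransformVariables' own code.

```julia
# Unconstrained point from the DomainError: (v₁, v₂, v₃, A, k, tau) in ℝ⁶.
x = [0.257204, 0.237231, 0.800883, -0.965525, -0.892265, -0.175778]

θ = exp.(x)        # asℝ₊ sends each coordinate through exp, so θ is strictly positive
logjac = sum(x)    # log|det J| of the elementwise exp map is the sum of x

println(all(>(0), θ))  # true
```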

Second, here is how I would recommend debugging this, either in the parameter or the unconstrained space:

bad_xs = LogDensityProblems.stresstest(LogDensityProblems.logdensity, P; scale = 0.001)
bad_θs = trans.(bad_xs)
LogDensityProblems.logdensity(P, bad_xs[1])
p(bad_θs[1])

I added the scale argument to rule out numerical problems in the first pass (once domain problems are fixed, you should remove it and retest). You can also change logdensity to logdensity_and_gradient once this is fixed.
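To show what the `scale` argument does in this debugging loop, here is a hypothetical, stdlib-only stand-in for `LogDensityProblems.stresstest`. It is a simplification: it draws points uniformly from [-scale, scale]ⁿ, whereas the real function may use a different sampling scheme. The names `find_bad_points` and `toy` are illustrative only.

```julia
using Random

"""
    find_bad_points(f, n; N = 1000, scale = 1.0, rng = Random.default_rng())

Hypothetical, simplified stand-in for `LogDensityProblems.stresstest`: evaluates
`f` at `N` random points in [-scale, scale]ⁿ and returns the points where `f`
throws or returns a non-finite value.
"""
function find_bad_points(f, n::Int; N::Int = 1000, scale::Real = 1.0,
                         rng::AbstractRNG = Random.default_rng())
    bad = Vector{Vector{Float64}}()
    for _ in 1:N
        x = scale .* (2 .* rand(rng, n) .- 1)   # uniform on [-scale, scale]ⁿ
        ok = try
            isfinite(f(x))
        catch
            false                               # a thrown error counts as a failure
        end
        ok || push!(bad, x)
    end
    return bad
end

# Toy log density with a "hole": -Inf whenever x[1] > 0.5.
toy(x) = x[1] > 0.5 ? -Inf : -sum(abs2, x) / 2

small = find_bad_points(toy, 3; scale = 0.001, rng = MersenneTwister(1))
large = find_bad_points(toy, 3; scale = 1.0, rng = MersenneTwister(1))
println(length(small), " bad points at scale 0.001; ", length(large), " at scale 1.0")
```

A small scale keeps the test points near the origin, which isolates domain problems there; widening the scale afterwards exposes regions (like the hole in `toy`) where the density stops being finite. Each returned point can then be mapped through the transformation and fed to the untransformed problem, as in the snippet above.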

Again, please do not hesitate to ask for help if you get stuck.

3 participants