Skip to content

Commit

Permalink
Fixed some functions and with that a few posteriors
Browse files Browse the repository at this point in the history
  • Loading branch information
nsiccha committed Dec 13, 2024
1 parent f0840d5 commit cf7cc37
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 18 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ sys
sysconfig
cache
*.ipynb
*.quarto_ipynb
*.sh
.jupyter_cache
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NaNStatistics = "b946abbf-3ea7-4610-9019-9858bfdeaf2d"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PosteriorDB = "1c4bc282-d2f5-44f9-b6d1-8c4424a23ad4"
Expand Down
6 changes: 3 additions & 3 deletions docs/julia/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cd(dirname(Base.active_project()))
using LinearAlgebra
using DataFrames, StanBlocks, Markdown, PosteriorDB, StanLogDensityProblems, LogDensityProblems, Statistics, OrderedCollections, PrettyTables, Serialization, Chairmarks, Enzyme, BridgeStan, Markdown, Pkg, Mooncake
using Logging, Test, LinearAlgebra
using Statistics, StanBlocks, PosteriorDB, Distributions, Random
using Statistics, StanBlocks, PosteriorDB, Distributions, Random, OrdinaryDiffEq
using StatsPlots
getsampletimes(x::Chairmarks.Benchmark) = getfield.(x.samples, :time);
const ENZYME_VERSION = filter(x->x.second.name=="Enzyme", Pkg.dependencies()) |> x->first(x)[2].version
Expand Down Expand Up @@ -144,7 +144,7 @@ compute_property(e, ::Val{:lpdf_accuracy}) = finite_relative_difference(
e.julia_lpdfs .+ median(filter(isfinite, e.stan_lpdfs - e.julia_lpdfs)),
e.stan_lpdfs
)
compute_property(e, ::Val{:usable}) = !isnothing(e.lpdf_accuracy) && e.lpdf_accuracy <= 1e-8
compute_property(e, ::Val{:usable}) = !isnothing(e.lpdf_accuracy) && e.lpdf_accuracy <= (e.posterior_name in ("sir-sir","one_comp_mm_elim_abs-one_comp_mm_elim_abs", "soil_carbon-soil_incubation", "hudson_lynx_hare-lotka_volterra") ? 1e-4 : 1e-8)
compute_property(e, ::Val{:julia_lpdf_benchmark}) = (@be randn(e.dimension) e.julia_lpdf)
compute_property(e, ::Val{:stan_lpdf_benchmark}) = (@be randn(e.dimension) e.stan_lpdf)
compute_property(e, ::Val{:lpdf_comparison}) = begin
Expand Down Expand Up @@ -209,7 +209,7 @@ compute_property(e, ::Val{:df_row}) = begin
enzyme_crashed = isnothing(e.enzyme_accuracy)
merge(
row,
(;enzyme_accuracy=something(e.enzyme_accuracy, missing), mooncake_accuracy=something(e.mooncake_accuracy, missing)),
(;enzyme_accuracy=something(e.enzyme_accuracy, "FAILED"), mooncake_accuracy=something(e.mooncake_accuracy, "FAILED")),
something(e.lpdf_comparison, (;)),
!enzyme_crashed ? something(e.gradient_comparison, (;)) : (;)
)
Expand Down
19 changes: 15 additions & 4 deletions docs/performance.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ This page compares the performance of Julia's and Stan's log density and log den
```{julia}
include("julia/common.jl")
jdf = map(posterior_names) do posterior_name
PosteriorEvaluation(posterior_name).df_row
end |> pad_missing |> DataFrame;
PosteriorEvaluation(posterior_name).df_row
end |> pad_missing |> DataFrame;
```

:::{.column-page}
Expand Down Expand Up @@ -118,10 +118,19 @@ The below table shows information about the implemented posteriors. Will elabora

```{julia}
ternary(c, t, f) = c ? t : f
hl_best = HtmlHighlighter(
(data, i, j) -> (isa(data[i,j], UncertainStatistic) && isapprox(val(data[i,j]), 1.)),
HtmlDecoration(color = "blue")
);
hl_failed = HtmlHighlighter(
(data, i, j) -> ((data[i,j]==="FAILED")),
HtmlDecoration(color = "red")
);
pretty_table(
DataFrame(OrderedDict(
"posterior name"=>jdf.posterior_name,
"dimension"=>jdf.dimension,
"remaining relative lpdf error"=>jdf.lpdf_accuracy,
"relative mean primitive Julia runtime"=>jdf.julia_lpdf_times,
"relative mean primitive Stan runtime"=>jdf.stan_lpdf_times,
"relative mean Enzyme runtime"=>jdf.enzyme_times,
Expand All @@ -131,14 +140,16 @@ pretty_table(
"mean Enzyme allocations"=>jdf.enzyme_allocs,
"mean Mooncake allocations"=>jdf.mooncake_allocs,
"constant lpdf difference"=>jdf.lpdf_difference,
"remaining relative lpdf error"=>jdf.lpdf_accuracy,
"Enzyme relative gradient error"=>jdf.enzyme_accuracy,
"Mooncake relative gradient error"=>jdf.mooncake_accuracy,
"Enzyme"=>jdf.ENZYME_VERSION,
"Mooncake"=>jdf.MOONCAKE_VERSION,
"implementations"=>implementations_string.(jdf.posterior_name)
));
backend=Val(:html), show_subheader=false, table_class="interactive"
backend=Val(:html),
highlighters=(hl_best,hl_failed),
show_subheader=false,
table_class="interactive"
)
```
:::
6 changes: 3 additions & 3 deletions ext/PosteriorDBExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1557,8 +1557,7 @@ julia_implementation(::Val{:dogs_nonhierarchical}; n_dogs, n_trials, y, kwargs..
z::matrix[J, 2]
end
@model @views begin
logit_ab = rep_vector(1, J) * mu_logit_ab'
+ z * diag_pre_multiply(sigma_logit_ab, L_logit_ab);
logit_ab = rep_vector(1, J) * mu_logit_ab' + z * diag_pre_multiply(sigma_logit_ab, L_logit_ab);
a = inv_logit.(logit_ab[ : , 1]);
b = inv_logit.(logit_ab[ : , 2]);
y ~ bernoulli(@broadcasted(a ^ prev_shock * b ^ prev_avoid));
Expand Down Expand Up @@ -2437,6 +2436,7 @@ julia_implementation(::Val{:prophet};


logistic_gamma(k, m, delta, t_change, S) = begin
local gamma = zeros(S)
k_s = append_row(k, k + cumulative_sum(delta));

m_pr = m;
Expand All @@ -2448,7 +2448,7 @@ julia_implementation(::Val{:prophet};
end

logistic_trend(k, m, delta, t, cap, A, t_change, S) = begin
gamma = logistic_gamma(k, m, delta, t_change, S);
local gamma = logistic_gamma(k, m, delta, t_change, S);
return cap .* inv_logit.((k .+ A * delta) .* (t .- m .- A * gamma));
end

Expand Down
26 changes: 18 additions & 8 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# @inline bsum(x::Base.Broadcast.Broadcasted{Style,Axes,typeof(-),Tuple{T1,T2}}) where {Style,Axes,T1,T2} = flength(x) * (
# bsum(x.args[1]) / flength(x.args[1]) - bsum(x.args[2]) / flength(x.args[2])
# )
const loggamma = Distributions.SpecialFunctions.loggamma
begin
bsum_expr(::Type; x) = :(sum($x)/length($x))
bsum_expr(::Type{Base.Broadcast.Broadcasted{Style,Axes,typeof(+),Args}}; x) where {Style,Axes,Args} = Expr(:call, :+, [
Expand All @@ -30,9 +31,9 @@ ternary(c,t,f) = c ? t : f
# https://mc-stan.org/docs/functions-reference/real-valued_basic_functions.html#betafun
@inline choose(x, y) = binomial(x, y)
@inline lchoose(x, y) = (
+ Distributions.SpecialFunctions.loggamma(x+1)
- Distributions.SpecialFunctions.loggamma(y+1)
- Distributions.SpecialFunctions.loggamma(x-y+1)
+ loggamma(x+1)
- loggamma(y+1)
- loggamma(x-y+1)
)
# https://mc-stan.org/docs/functions-reference/unbounded_discrete_distributions.html#nbalt
@inline neg_binomial_2_lpdf(n, mu, phi) = bsum(@broadcasted(
Expand Down Expand Up @@ -85,11 +86,12 @@ ternary(c,t,f) = c ? t : f
@inline normal_lpdf(x, mu, sigma) = -bsum(@broadcasted(log(sigma)+.5*square((x-mu)/sigma)))
# https://mc-stan.org/docs/2_21/functions-reference/normal-id-glm.html
@inline normal_id_glm_lpdf(y,X,alpha,beta,sigma) = normal_lpdf(y, Base.broadcasted(+, alpha, X * beta), sigma)
# https://mc-stan.org/docs/functions-reference/positive_continuous_distributions.html#lognormal
@inline lognormal_lpdf(x, mu, sigma) = begin
bsum(@broadcasted(
-log(sigma)
-.5*log(x)
-.5*square((log(x)-mu)/sigma)
-bsum(@broadcasted(
log(sigma)
+log(x)
+.5*square((log(x)-mu)/sigma)
))
end
@inline weibull_lpdf(y, alpha, sigma) = bsum(@broadcasted(
Expand All @@ -100,9 +102,17 @@ end
))
@inline StudentT(nu, mu, sigma) = mu + sigma * TDist(nu)
@inline student_t_lpdf(x, args...) = bsum(@broadcasted(logpdf(StudentT(args...), x)))
# https://mc-stan.org/docs/functions-reference/unbounded_continuous_distributions.html#student-t-distribution
# @inline student_t_lpdf(y, nu, mu, sigma) = -bsum(@broadcasted(
# - loggamma((nu+1)/2)
# + loggamma(nu/2)
# + .5 * log(nu)
# + log(sigma)
# + .5 * (nu+1) * (log1p(square((y-mu)/sigma)/nu))
# ))
# https://mc-stan.org/docs/functions-reference/unbounded_continuous_distributions.html#cauchy-distribution
@inline cauchy_lpdf(x, location, scale) = begin
-bsum(@broadcasted(log(scale) + log1p((x-location)/scale)))
-bsum(@broadcasted(log(scale) + log1p(((x-location)/scale)^2)))
end
# https://mc-stan.org/docs/functions-reference/unbounded_continuous_distributions.html#logistic-distribution
@inline logistic_lpdf(y, location, scale) = begin
Expand Down

0 comments on commit cf7cc37

Please sign in to comment.