Skip to content

Commit

Permalink
Better dist (#81)
Browse files Browse the repository at this point in the history
* use better dist

* update Dense api usage

* add I to sigma
  • Loading branch information
prbzrg authored Jul 16, 2022
1 parent 98c51ef commit 55be9f0
Show file tree
Hide file tree
Showing 9 changed files with 15 additions and 12 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -40,6 +41,7 @@ DataFrames = "1.2"
DifferentialEquations = "7.0"
Distributions = "0.25"
DistributionsAD = "0.6"
FillArrays = "0.13"
Flux = "0.13"
IterTools = "1.3"
MLJBase = "0.20"
Expand Down
1 change: 1 addition & 0 deletions src/ICNF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module ICNF
DifferentialEquations,
Distributions,
DistributionsAD,
FillArrays,
Flux,
IterTools,
MLJBase,
Expand Down
2 changes: 1 addition & 1 deletion src/cond_ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function CondFFJORD{T}(
nn,
nvars::Integer,
;
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/cond_planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ function CondPlanar{T}(
nn::PlanarNN,
nvars::Integer,
;
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/cond_rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function CondRNODE{T}(
nn,
nvars::Integer,
;
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function FFJORD{T}(
nn,
nvars::Integer,
;
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/planar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function Planar{T}(
nn::PlanarNN,
nvars::Integer,
;
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
2 changes: 1 addition & 1 deletion src/rnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function RNODE{T}(
nn,
nvars::Integer,
;
basedist::Distribution=MvNormal(zeros(T, nvars), Diagonal(ones(T, nvars))),
basedist::Distribution=MvNormal(Zeros{T}(nvars), one(T)*I),
tspan::Tuple{T, T}=convert(Tuple{T, T}, (0, 1)),

solvealg_test::SciMLBase.AbstractODEAlgorithm=default_solvealg,
Expand Down
12 changes: 6 additions & 6 deletions test/smoke_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
nn = PlanarNN(nvars, tanh)
else
nn = Chain(
Dense(nvars, nvars, tanh),
Dense(nvars => nvars, tanh),
)
end
icnf = mt{tp}(nn, nvars; acceleration=cr)
Expand Down Expand Up @@ -128,7 +128,7 @@
nn = PlanarNN(nvars, tanh)
else
nn = Chain(
Dense(nvars, nvars, tanh),
Dense(nvars => nvars, tanh),
)
end
icnf = mt{tp}(nn, nvars; acceleration=cr)
Expand All @@ -147,7 +147,7 @@
nn = PlanarNN(nvars, tanh)
else
nn = Chain(
Dense(nvars, nvars, tanh),
Dense(nvars => nvars, tanh),
)
end
icnf = mt{tp}(nn, nvars; acceleration=cr)
Expand All @@ -174,7 +174,7 @@
nn = PlanarNN(nvars, tanh; cond=true)
else
nn = Chain(
Dense(nvars*2, nvars, tanh),
Dense(2*nvars => nvars, tanh),
)
end
icnf = mt{tp}(nn, nvars; acceleration=cr)
Expand Down Expand Up @@ -262,7 +262,7 @@
nn = PlanarNN(nvars, tanh; cond=true)
else
nn = Chain(
Dense(nvars*2, nvars, tanh),
Dense(2*nvars => nvars, tanh),
)
end
icnf = mt{tp}(nn, nvars; acceleration=cr)
Expand All @@ -281,7 +281,7 @@
nn = PlanarNN(nvars, tanh; cond=true)
else
nn = Chain(
Dense(nvars*2, nvars, tanh),
Dense(2*nvars => nvars, tanh),
)
end
icnf = mt{tp}(nn, nvars; acceleration=cr)
Expand Down

0 comments on commit 55be9f0

Please sign in to comment.