Skip to content

Commit

Permalink
bugfix: big loss (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored Jan 16, 2023
1 parent f2c5776 commit 13149f8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
y = convert(model.array_type, y)
data = DataLoader((x, y); batchsize = model.batch_size, shuffle = true, partial = true)
ncdata = ncycle(data, model.n_epochs)
initial_loss_value = model.loss(model.m, x, y)
initial_loss_value = model.loss(model.m, first(data)...)

if model.opt_app isa FluxOptApp
model.optimizer isa Optimisers.AbstractRule ||
Expand Down Expand Up @@ -308,7 +308,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
tst = @timed res = solve(optprob, model.optimizer, ncdata; callback = _callback)
model.m.p .= res.u
end
final_loss_value = model.loss(model.m, x, y)
final_loss_value = model.loss(model.m, first(data)...)
@info(
"Fitting",
"elapsed time (seconds)" = tst.time,
Expand Down
4 changes: 2 additions & 2 deletions src/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
x = convert(model.array_type, x)
data = DataLoader((x,); batchsize = model.batch_size, shuffle = true, partial = true)
ncdata = ncycle(data, model.n_epochs)
initial_loss_value = model.loss(model.m, x)
initial_loss_value = model.loss(model.m, first(data)...)

if model.opt_app isa FluxOptApp
model.optimizer isa Optimisers.AbstractRule ||
Expand Down Expand Up @@ -291,7 +291,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
tst = @timed res = solve(optprob, model.optimizer, ncdata; callback = _callback)
model.m.p .= res.u
end
final_loss_value = model.loss(model.m, x)
final_loss_value = model.loss(model.m, first(data)...)
@info(
"Fitting",
"elapsed time (seconds)" = tst.time,
Expand Down

0 comments on commit 13149f8

Please sign in to comment.