diff --git a/src/optimize.jl b/src/optimize.jl
index 5bcd0c3..984e151 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -496,20 +496,6 @@ end
 @inline view_slice_last(X::AbstractArray{<:Any,3}, r) = view(X, :, :, r)
 @inline view_slice_last(X::AbstractArray{<:Any,4}, r) = view(X, :, :, :, r)
 @inline view_slice_last(X::AbstractArray{<:Any,5}, r) = view(X, :, :, :, :, r)
-"""
-    train_batched!(g::AbstractVecOrMat, p, chn, X, opt, iters; batchsize = nothing)
-
-Train while batching arguments.
-
-Arguments:
-- `g` pre-allocated gradient buffer. Can be allocated with `similar(p)` (if you want to run single threaded), or `alloc_threaded_grad(chn, size(X))` (`size(X)` argument is only necessary if the input dimension was not specified when constructing the chain). If a matrix, the number of columns gives how many threads to use. Do not use more threads than batch size would allow.
-- `p` is the parameter vector. It is updated inplace. It should be pre-initialized, e.g. with `init_params`/`init_params!`. This is to allow calling `train_unbatched!` several times to train in increments.
-- `chn` is the `SimpleChain`. It must include a loss (see `SimpleChains.add_loss`) containing the target information (dependent variables) you're trying to fit.
-- `X` the training data input argument (independent variables).
-- `opt` is the optimizer. Currently, only `SimpleChains.ADAM` is supported.
-- `iters`, how many iterations to train for.
-- `batchsize` keyword argument: the size of the batches to use. If `batchsize = nothing`, it'll try to do a half-decent job of picking the batch size for you. However, this is not well optimized at the moment.
-"""
 function train_batched_core!(
   _chn::Chain,
   pu::Ptr{UInt8},
@@ -593,6 +579,20 @@ function train_batched_core!(
   offset = static_sizeof(T) * aligned_glen * numthreads
   train_batched_core!(c, pu + offset, g, p, pX, opt, iters, leaveofflast, mpt, N_bs)
 end
+"""
+    train_batched!(g::AbstractVecOrMat, p, chn, X, opt, iters; batchsize = nothing)
+
+Train while batching arguments.
+
+Arguments:
+- `g` pre-allocated gradient buffer. Can be allocated with `similar(p)` (if you want to run single threaded), or `alloc_threaded_grad(chn, size(X))` (`size(X)` argument is only necessary if the input dimension was not specified when constructing the chain). If a matrix, the number of columns gives how many threads to use. Do not use more threads than batch size would allow.
+- `p` is the parameter vector. It is updated inplace. It should be pre-initialized, e.g. with `init_params`/`init_params!`. This is to allow calling `train_unbatched!` several times to train in increments.
+- `chn` is the `SimpleChain`. It must include a loss (see `SimpleChains.add_loss`) containing the target information (dependent variables) you're trying to fit.
+- `X` the training data input argument (independent variables).
+- `opt` is the optimizer. Currently, only `SimpleChains.ADAM` is supported.
+- `iters`, how many iterations to train for.
+- `batchsize` keyword argument: the size of the batches to use. If `batchsize = nothing`, it'll try to do a half-decent job of picking the batch size for you. However, this is not well optimized at the moment.
+"""
 function train_batched!(
   g::Union{Nothing,AbstractVector,AbstractMatrix},
   p::AbstractVector,
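
For reference, a minimal usage sketch of the API described by the relocated docstring. The network architecture, data, and hyperparameters below are illustrative assumptions, not taken from this diff; only `add_loss`, `init_params`, `alloc_threaded_grad`, `ADAM`, and `train_batched!` come from the docstring itself.

```julia
using SimpleChains

# Toy regression data: 4 input features, 2 targets, 200 samples (illustrative only).
X = rand(Float32, 4, 200)
Y = rand(Float32, 2, 200)

chn = SimpleChain(
  static(4),              # input dimension fixed at construction
  TurboDense(tanh, 16),
  TurboDense(identity, 2),
)

# Attach a loss carrying the targets, as the docstring requires of `chn`.
loss = SimpleChains.add_loss(chn, SquaredLoss(Y))

p = SimpleChains.init_params(chn)          # pre-initialized parameters, updated in place
g = SimpleChains.alloc_threaded_grad(loss) # threaded gradient buffer; use `similar(p)` to run single threaded

# Train for 1_000 iterations with ADAM; `batchsize = nothing` would let it pick a batch size.
SimpleChains.train_batched!(g, p, loss, X, SimpleChains.ADAM(), 1_000; batchsize = 32)
```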