Merge pull request #111 from PumasAI/trainbatcheddocstring
move `train_batched!` docstring to correct function
chriselrod authored Sep 27, 2022
2 parents 9db72ca + 791e92b commit f890bd5
Showing 1 changed file with 14 additions and 14 deletions.
src/optimize.jl: 28 changes (14 additions & 14 deletions)
@@ -496,20 +496,6 @@ end
 @inline view_slice_last(X::AbstractArray{<:Any,3}, r) = view(X, :, :, r)
 @inline view_slice_last(X::AbstractArray{<:Any,4}, r) = view(X, :, :, :, r)
 @inline view_slice_last(X::AbstractArray{<:Any,5}, r) = view(X, :, :, :, :, r)
-"""
-    train_batched!(g::AbstractVecOrMat, p, chn, X, opt, iters; batchsize = nothing)
-
-Train while batching arguments.
-
-Arguments:
-- `g` pre-allocated gradient buffer. Can be allocated with `similar(p)` (if you want to run single threaded), or `alloc_threaded_grad(chn, size(X))` (`size(X)` argument is only necessary if the input dimension was not specified when constructing the chain). If a matrix, the number of columns gives how many threads to use. Do not use more threads than batch size would allow.
-- `p` is the parameter vector. It is updated inplace. It should be pre-initialized, e.g. with `init_params`/`init_params!`. This is to allow calling `train_unbatched!` several times to train in increments.
-- `chn` is the `SimpleChain`. It must include a loss (see `SimpleChains.add_loss`) containing the target information (dependent variables) you're trying to fit.
-- `X` the training data input argument (independent variables).
-- `opt` is the optimizer. Currently, only `SimpleChains.ADAM` is supported.
-- `iters`, how many iterations to train for.
-- `batchsize` keyword argument: the size of the batches to use. If `batchsize = nothing`, it'll try to do a half-decent job of picking the batch size for you. However, this is not well optimized at the moment.
-"""
 function train_batched_core!(
   _chn::Chain,
   pu::Ptr{UInt8},
@@ -593,6 +579,20 @@ function train_batched_core!(
   offset = static_sizeof(T) * aligned_glen * numthreads
   train_batched_core!(c, pu + offset, g, p, pX, opt, iters, leaveofflast, mpt, N_bs)
 end
+"""
+    train_batched!(g::AbstractVecOrMat, p, chn, X, opt, iters; batchsize = nothing)
+
+Train while batching arguments.
+
+Arguments:
+- `g` pre-allocated gradient buffer. Can be allocated with `similar(p)` (if you want to run single threaded), or `alloc_threaded_grad(chn, size(X))` (`size(X)` argument is only necessary if the input dimension was not specified when constructing the chain). If a matrix, the number of columns gives how many threads to use. Do not use more threads than batch size would allow.
+- `p` is the parameter vector. It is updated inplace. It should be pre-initialized, e.g. with `init_params`/`init_params!`. This is to allow calling `train_unbatched!` several times to train in increments.
+- `chn` is the `SimpleChain`. It must include a loss (see `SimpleChains.add_loss`) containing the target information (dependent variables) you're trying to fit.
+- `X` the training data input argument (independent variables).
+- `opt` is the optimizer. Currently, only `SimpleChains.ADAM` is supported.
+- `iters`, how many iterations to train for.
+- `batchsize` keyword argument: the size of the batches to use. If `batchsize = nothing`, it'll try to do a half-decent job of picking the batch size for you. However, this is not well optimized at the moment.
+"""
 function train_batched!(
   g::Union{Nothing,AbstractVector,AbstractMatrix},
   p::AbstractVector,
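For context, here is a minimal sketch of how `train_batched!` is meant to be called, following the docstring moved in this commit. The data, chain architecture, loss, and hyperparameters below are illustrative placeholders and are not part of the commit; only the function names (`SimpleChain`, `TurboDense`, `SquaredLoss`, `SimpleChains.add_loss`, `init_params`, `alloc_threaded_grad`, `train_batched!`, `SimpleChains.ADAM`) are assumed from the SimpleChains.jl API.

using SimpleChains

# Illustrative data: 4-dimensional inputs (independent variables) and
# 2-dimensional targets (dependent variables).
X = rand(Float32, 4, 200)
Y = rand(Float32, 2, 200)

# A small MLP; the input dimension is given up front via `static(4)`.
model = SimpleChain(
  static(4),
  TurboDense(tanh, 32),
  TurboDense(identity, 2),
)

# The chain passed to `train_batched!` must carry a loss holding the targets.
model_loss = SimpleChains.add_loss(model, SquaredLoss(Y))

p = SimpleChains.init_params(model)              # parameter vector, updated in place
g = SimpleChains.alloc_threaded_grad(model_loss) # or `similar(p)` to run single threaded

# Train for 1_000 iterations with ADAM, batching the 200 samples in groups of 50.
SimpleChains.train_batched!(g, p, model_loss, X, SimpleChains.ADAM(3e-4), 1_000; batchsize = 50)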
