Add tabular model #124
Conversation
The embedding layer used here is based on FluxML/Flux.jl#1516.
src/models/tabularmodel.jl (outdated)

    x = tm.emb_drop(x)
end
if tm.n_cont != 0
    if (tm.bn_cont != false)
Equivalent to `if tm.bn_cont == true`, which is equivalent to `if tm.bn_cont`.

Suggested change:

- if (tm.bn_cont != false)
+ if tm.bn_cont
src/models/tabularmodel.jl (outdated)

actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
_layers = [linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=(use_bn && ((i!=(length(actns)-1)) || bn_final)), p=p, act=a, lin_first=lin_first) for (i, (p, a)) in enumerate(zip(push!(ps, 0.), actns))]
if !isnothing(y_range)
    push!(_layers, Chain(@. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1]))
This doesn't need to be wrapped in a `Chain`. Also, is the sigmoid range meant to replace the activation of the last linear layer, or to run after it?
In the case of `y_range == nothing`, should it be `sigmoid`? Right now there is no activation at all.
If it should be `sigmoid`, then I think it makes more sense to force `y_range::Tuple` and make the default `(0, 1)`. Just avoid `nothing` entirely.
No, I don't think there is an activation if `y_range == nothing`, according to the Python implementation.
src/models/tabularmodel.jl (outdated)

x = [e(x_cat[i, :]) for (i, e) in enumerate(tm.embeds)]
x = vcat(x...)
x = tm.emb_drop(x)
I wonder if there's a way to use `Parallel` for `tm.embeds` instead of writing all this out manually?
You could do `Parallel(vcat, [Embedding(ni, nf) for (ni, nf) in emb_szs])`. In the `Chain`, before the embedding layers, you would need `x -> ntuple(i -> x[i, :], length(emb_szs))` to split a single `x` into multiple arguments.
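A rough sketch of that idea, for illustration only (the `emb_szs` values, batch size, and variable names below are invented; assumes a Flux version that has `Embedding` and `Parallel`):

using Flux

emb_szs = [(10, 5), (4, 3)]  # hypothetical (vocab size, embedding dim) per column

embeds = Chain(
    # split the (n_cols × batch) index matrix into one row per categorical column
    x -> ntuple(i -> x[i, :], length(emb_szs)),
    # each Embedding receives its own row; vcat stacks the resulting features
    Parallel(vcat, [Flux.Embedding(ni, nf) for (ni, nf) in emb_szs]...),
)

x_cat = rand(1:4, 2, 7)  # 2 categorical columns, batch of 7
size(embeds(x_cat))      # (8, 7): 5 + 3 embedding features per sample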
I think splatting `eachslice(x, dims=1)` would work too.
Yeah, I considered using `Parallel` first but just wasn't sure how to handle the input for that. This should simplify things a lot.
src/models/tabularmodel.jl (outdated)

n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
sizes = append!(zeros(0), [n_emb+n_cont], layers, [out_sz])
actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
_layers = [linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=(use_bn && ((i!=(length(actns)-1)) || bn_final)), p=p, act=a, lin_first=lin_first) for (i, (p, a)) in enumerate(zip(push!(ps, 0.), actns))]
This could benefit from splitting into a few lines. Also, do you recall why the `Int64` cast was added?
I would echo this suggestion. In particular, instead of having complicated `use_bn` logic, just use two lines: first, iterate over the `1:(n - 1)` elements, then just do `push!(_layers, # last element version)`. Also, I think this will be cleaner as a standard `for` loop instead of a generator. Instead of pre-computing `sizes`, could we use `Flux.outputsize` inside the loop?
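For context, `Flux.outputsize` runs a shape-only forward pass, so layer sizes don't have to be tracked by hand; a small illustration with made-up sizes:

using Flux

m = Chain(Dense(10, 64, relu), Dense(64, 32, relu))

# Output shape for a (features, batch) input, computed without real data:
Flux.outputsize(m, (10, 1))  # == (32, 1)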
I think the `Int64` cast was added because `BatchNorm` and `Dense` don't seem to work with floating-point sizes (the `sizes` array is `Float64` because of `zeros(0)`). I could maybe restrict the types to `Int64` only though.
Yeah, I can create the last `linbndrop` outside the loop. Although I'm not sure how the sizes for `Flux.outputsize` would work with the `Parallel` layer before it. Would it be better than using `sizes`?
This is a good start, but I think we can improve the readability on top of the original Python code a lot here. In particular, I think there are two places where `Parallel` will intuitively express the operation and eliminate a lot of boilerplate, both in the constructor and the forward pass.
src/models/tabularmodel.jl (outdated)

emb_szs,
n_cont,
out_sz,
ps::Union{Tuple, Vector, Number, Nothing}=nothing,
I would recommend pushing the `isnothing` etc. logic into `linbndrop` (it already handles the 0 case). It can be extended to handle `nothing` and `false` too.
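A hedged sketch of what that extension might look like. The body below is a guess at `linbndrop`'s behavior (Dense plus optional BatchNorm and Dropout), based only on how it is called in this PR, not the actual FastAI.jl implementation:

using Flux

function linbndrop(in_sz, out_sz; use_bn=true, p=0., act=identity, lin_first=true)
    # nothing and false both collapse to "no dropout", alongside the existing 0 case
    p = (isnothing(p) || p === false) ? 0. : p
    bn = use_bn ? BatchNorm(lin_first ? out_sz : in_sz) : identity
    drop = iszero(p) ? identity : Dropout(p)
    lin = Dense(in_sz, out_sz, act)
    return lin_first ? Chain(lin, bn, drop) : Chain(bn, drop, lin)
end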
Okay. I'm even considering removing `nothing`, as it won't work directly with `Iterators.cycle` if we do use that for `ps`.
src/models/tabularmodel.jl (outdated)

if isnothing(ps)
    ps = zeros(length(layers))
end
if ps isa Number
    ps = fill(ps, length(layers))
end
In addition to the comment about `isnothing(p)` being part of `linbndrop`, this can also be deleted by doing `Iterators.cycle(ps)` when building `_layers`.
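`Iterators.cycle` repeats its argument endlessly and `zip` stops at the shorter iterator, so a scalar dropout probability gets recycled across all layers; for example:

# A bare number iterates over itself, so cycle(0.5) yields 0.5 forever;
# zip truncates to the finite 1:3.
collect(zip(Iterators.cycle(0.5), 1:3))
# => [(0.5, 1), (0.5, 2), (0.5, 3)]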
Yeah I'll try doing this.
src/models/tabularmodel.jl (outdated)

end
embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
emb_drop = Dropout(embed_p)
bn_cont = bn_cont ? BatchNorm(n_cont) : false
Could make this `bn_cont = bn_cont ? BatchNorm(n_cont) : identity` and then avoid the if/else logic in the forward pass.
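A tiny sketch of why that works (sizes invented): with `identity` as the fallback, the same call is valid either way, so the forward pass needs no branch.

using Flux

use_bn = false
bn_cont = use_bn ? BatchNorm(5) : identity

x_cont = rand(Float32, 5, 8)
bn_cont(x_cont)  # identity passes x_cont through unchanged; no if/else needed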
src/models/tabularmodel.jl (outdated)

bn_cont = bn_cont ? BatchNorm(n_cont) : false
n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
sizes = append!(zeros(0), [n_emb+n_cont], layers, [out_sz])
actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
This can be avoided if you split the final layer into a separate line.
src/models/tabularmodel.jl (outdated)

if tm.n_emb != 0
    x = [e(x_cat[i, :]) for (i, e) in enumerate(tm.embeds)]
    x = vcat(x...)
    x = tm.emb_drop(x)
end
if tm.n_cont != 0
    if (tm.bn_cont != false)
        x_cont = tm.bn_cont(x_cont)
    end
    x = tm.n_emb!=0 ? vcat(x, x_cont) : x_cont
end
I would suggest making the whole categorical-branch/continuous-branch portion of the model into a `Parallel`, something like `Parallel(vcat, embeds, bn_cont)`. It should gracefully handle the cases when either the categorical or continuous vectors are empty.
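A hedged sketch of that two-branch layout (names and sizes invented; `embeds` reuses the Parallel-embedding pattern from earlier in the thread):

using Flux

emb_szs = [(10, 5), (4, 3)]
n_cont = 6

embeds = Chain(
    x -> ntuple(i -> x[i, :], length(emb_szs)),
    Parallel(vcat, [Flux.Embedding(ni, nf) for (ni, nf) in emb_szs]...),
)
bn_cont = BatchNorm(n_cont)

# Categorical indices feed the embedding branch, continuous features the
# batch-norm branch; vcat concatenates the two feature blocks.
backbone = Parallel(vcat, embeds, bn_cont)

x_cat = rand(1:4, 2, 7)
x_cont = rand(Float32, n_cont, 7)
size(backbone(x_cat, x_cont))  # (5 + 3 + 6, 7) == (14, 7)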
Force-pushed from 90e5bc9 to fa5c563.
From the call:
src/models/tabularmodel.jl (outdated)

_layers = []
for (i, (p, a)) in enumerate(zip(Iterators.cycle(ps), actns))
    layer = linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=use_bn, p=p, act=a, lin_first=lin_first)
    push!(_layers, layer)
end
push!(_layers, linbndrop(Int64(last(sizes)), Int64(out_sz), use_bn=bn_final, lin_first=lin_first))
Suggested change (replacing the loop above):

n_emb = first(Flux.outputsize(embeds, (length(emb_szs), 1)))
classifiers = [
    linbndrop(n_cat, first(layers); use_bn=use_bn, p=first(ps), act=act_cls, lin_first=lin_first),
    [linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first)
     for (isize, osize, p) in zip(layers[1:(end - 2)], layers[2:(end - 1)], Iterators.cycle(Base.tail(ps)))]...,
    linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first)
]
I would go so far as to convert the comprehension into a plain `for` loop with `push!`. It's not any less efficient than the splat, and it would help with line lengths.
Yeah, maybe that's better here since we are using so many iterables in the comprehension.
Seeing that `Flux.outputsize` isn't working with `Embedding` layers, I was thinking of keeping the current way of calculating `n_emb` for now.
src/models/tabularmodel.jl (outdated)

function emb_sz_rule(n_cat)
Can we add a comment with a link to where this is taken from?
Sure, I can add this link. I believe they got this formula experimentally.
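For reference, my understanding of fastai's empirically derived rule of thumb, transcribed to Julia (treat the constants as fastai's, not something derived here):

emb_sz_rule(n_cat) = min(600, round(Int, 1.6 * n_cat^0.56))

emb_sz_rule(10)       # => 6
emb_sz_rule(100_000)  # capped at 600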
Can you add it?
Looking really good. Left a few comments.
src/models/tabularmodel.jl (outdated)

bn_cont = bn_cont ? BatchNorm(n_cont) : identity
n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
This would be clearer with `Flux.outputsize`.
@manikyabard can you paste the error you got with `Embedding` and `outputsize`?
Sure, so if I do something like this:

Flux.outputsize(Embedding(2, 3), (2, 1))

the error I get is:
ArgumentError: unable to check bounds for indices of type Flux.NilNumber.Nil
Stacktrace:
[1] checkindex(#unused#::Type{Bool}, inds::Base.OneTo{Int64}, i::Flux.NilNumber.Nil)
@ Base ./abstractarray.jl:671
[2] checkindex
@ ./abstractarray.jl:686 [inlined]
[3] checkbounds_indices (repeats 2 times)
@ ./abstractarray.jl:642 [inlined]
[4] checkbounds
@ ./abstractarray.jl:595 [inlined]
[5] checkbounds
@ ./abstractarray.jl:616 [inlined]
[6] _getindex
@ ./multidimensional.jl:831 [inlined]
[7] getindex
@ ./abstractarray.jl:1170 [inlined]
[8] Embedding
@ ~/.julia/packages/Flux/wii6E/src/layers/basic.jl:421 [inlined]
[9] (::Embedding{Matrix{Float32}})(x::Matrix{Flux.NilNumber.Nil})
@ Flux ~/.julia/packages/Flux/wii6E/src/layers/basic.jl:422
[10] #outputsize#279
@ ~/.julia/packages/Flux/wii6E/src/outputsize.jl:93 [inlined]
[11] outputsize(m::Embedding{Matrix{Float32}}, inputsizes::Tuple{Int64, Int64})
@ Flux ~/.julia/packages/Flux/wii6E/src/outputsize.jl:92
[12] top-level scope
@ In[213]:1
[13] eval
@ ./boot.jl:360 [inlined]
[14] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1094
src/models/tabularmodel.jl (outdated)

end
push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), classifiers...) : Chain(Parallel(vcat, embeds, bn_cont), classifiers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1])
Let's switch this to something like `final_activation = identity` as a kwarg.
Okay, then I can make a `sigmoidrange` function and remove the `y_range` part. If users want to use this, then something like `x -> sigmoidrange(x, 4, 6)` (for getting values between 4 and 6) should work.
Yup, that's what I was thinking too
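A minimal sketch of such a `sigmoidrange`, mirroring fastai's `sigmoid_range` (the name and exact placement are assumptions from this exchange):

using Flux

# Squash x into the open interval (low, high) with a scaled sigmoid.
sigmoidrange(x, low, high) = @. Flux.sigmoid(x) * (high - low) + low

sigmoidrange(0.0, 4, 6)                 # 5.0, the midpoint
sigmoidrange([-10.0, 0.0, 10.0], 4, 6)  # ≈ [4.0, 5.0, 6.0]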
Just some comments from the call. Also, we want to refactor as `TabularModel(EmbeddingBackbone, ContinuousBackbone, layers)` like we discussed.
src/models/tabularmodel.jl (outdated)

x -> collect(eachrow(x)),
x -> ntuple(i -> x[i], length(x)),
Suggested change (replacing both lines above):

x -> tuple(eachrow(x)...),
src/models/tabularmodel.jl (outdated)

end
push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
layers = Chain(
    x -> tuple(x...),
Suggested change: remove the `x -> tuple(x...),` line.
src/models/tabularmodel.jl (outdated)

push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
layers = Chain(
    x -> tuple(x...),
    Parallel(vcat, embeds, Chain(x -> ndims(x)==1 ? Flux.unsqueeze(x, 2) : x, bn_cont)),
Move the `unsqueeze` to the encode method.
Force-pushed from 7fff9d9 to eb2cfc8.
Great to have tests! I think there are just some imports missing, which is why they're failing.
Yeah I think it's probably because the …
Tests are still failing because …
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Force-pushed from eb2cfc8 to 04d27d4.
Let's maybe knock out the final restructuring of this on the next call. I think it's possible to eliminate a lot of these keywords while maintaining flexibility (or at least restructure them so that it is clearer what each keyword does).
src/models/tabularmodel.jl (outdated)

function classifierbackbone(
        layers;
        ps=0,
        use_bn=true,
        bn_final=false,
        act_cls=Flux.relu,
        lin_first=true)
    ps = Iterators.cycle(ps)
    classifiers = []
    for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:end], ps)
        layer = linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first)
        push!(classifiers, layer)
    end
    Chain(classifiers...)
end
I wasn't thinking that the linear chain should be customizable. I think this loop should get pushed into `TabularModel`. My suggestion was that `bn_final` and `final_activation` be lumped into a single `classifier` argument to `TabularModel`, which is positional and defaults to `Dense(in, out)`.
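A hedged sketch of an API with that shape (every name and default below is an assumption from this thread, not the final FastAI.jl signature):

using Flux

# Hypothetical: the dense trunk is built inside the model, and bn_final /
# final_activation collapse into a single positional classifier head.
function tabularmodel_sketch(embed_backbone, cont_backbone, trunk_sizes, out_sz;
                             classifier = Dense(last(trunk_sizes), out_sz))
    trunk = [Dense(isize, osize, relu)
             for (isize, osize) in zip(trunk_sizes[1:end-1], trunk_sizes[2:end])]
    Chain(Parallel(vcat, embed_backbone, cont_backbone), trunk..., classifier)
end

Here `trunk_sizes[1]` would have to match the concatenated width of the two backbones.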
src/models/tabularmodel.jl (outdated)

function continuousbackbone(n_cont)
    n_cont > 0 ? BatchNorm(n_cont) : identity
end
I don't know how useful it is to have this function.
src/models/tabularmodel.jl (outdated)

n_cont::Number,
out_sz::Number,
layers=[200, 100];
catdict,
I have some questions about understanding `catdict` that we can discuss during the call.
Mostly minor name changes to make things more readable. The only major change is to collapse `catcols` and `cardinalitydict` into a single argument. We want to optimize the API to be simple for the common case.
src/models/tabularmodel.jl (outdated)

- `use_bn`: Boolean variable which controls whether to use batch normalization in the classifier.
- `act_cls`: The activation function to use in the classifier layers.
- `lin_first`: Controls if the linear layer comes before or after BatchNorm and Dropout.
- `cardinalitydict`: An indexable collection which maps to the cardinality for each column present
Couldn't this be a vector of the cardinalities directly? Instead of having `catcols` be column names and `cardinalitydict` be a map from name => cardinality, just make `catcols` a vector of cardinalities. I also think that cardinality can be confusing lingo for non-ML-oriented users. Something like "vector of sizes (number of labels) for each categorical column" would be clearer.
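To make the contrast concrete, a hypothetical before/after at the call site (column names and counts invented):

# Before: column names plus a name => cardinality lookup.
catcols = [:day, :month]
cardinalitydict = Dict(:day => 7, :month => 12)

# After: just the sizes (number of labels), in column order.
catcols = [7, 12]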
src/models/tabularmodel.jl (outdated)

catcols,
n_cont::Number,
outsz::Number,
layers=[200, 100];
Use a tuple
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Small nitpicks mostly. Let's also split some of these docstrings so it is less confusing trying to explain multiple sets of arguments in a single paragraph.
src/models/tabularmodel.jl (outdated)

"""
    get_emb_sz(cardinalities, [size_overrides])
    get_emb_sz(cardinalities; catcols, [size_overrides])

Return a collection of tuples containing embedding dimensions corresponding to
the number of classes in the categorical columns present in `cardinalities`, adjusting for `NaN`s.

## Keyword arguments

- `size_overrides`: Depending on the method used, this could either be a collection of
  integers and `nothing`, or an indexable collection with the column name as key and the size
  to override it with as the value. In the first case, the integer present at any index
  will be used to override the rule of thumb for getting embedding sizes.
- `categorical_cols`: A collection of categorical column names.
"""
This docstring is better off split into two separate docstrings, one for each method. Both will show up automatically in the docs, but you will be able to tailor the explanation of `size_overrides` to each method (instead of explaining both in one paragraph).
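For illustration, one half of the split could look like this (the signature is taken from the docstring above; the body is hypothetical and assumes an `emb_sz_rule` like the one discussed earlier):

"""
    get_emb_sz(cardinalities, [size_overrides])

Compute embedding sizes from a vector of column cardinalities. An integer at
index `i` of `size_overrides` replaces the rule-of-thumb size for column `i`.
"""
function get_emb_sz(cardinalities::AbstractVector{<:Integer},
                    size_overrides=fill(nothing, length(cardinalities)))
    # hypothetical body: apply the rule of thumb unless overridden
    [(c, isnothing(o) ? emb_sz_rule(c) : o) for (c, o) in zip(cardinalities, size_overrides)]
end

with a second, similarly focused docstring attached to the keyword-argument method.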
src/models/tabularmodel.jl (outdated)

function tabular_embedding_backbone(embedding_sizes, dropoutprob=0.)
    embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes]
    emb_drop = dropoutprob==0. ? identity : Dropout(dropoutprob)
    Chain(
        x -> tuple(eachrow(x)...),
        Parallel(vcat, embedslist),
        emb_drop
    )
end
Suggested change (replacing the function above):

function tabular_embedding_backbone(embedding_sizes, dropout_rates=0.)
    embedslist = [Flux.Embedding(ni, nf) for (ni, nf) in embedding_sizes]
    emb_drop = iszero(dropout_rates) ? identity : Dropout(dropout_rates)
    Chain(
        x -> tuple(eachrow(x)...),
        Parallel(vcat, embedslist),
        emb_drop
    )
end
`emb_drop` could even be further inlined if desired. If preserving model structure is more important than a bit of lost performance, then passing `active = !iszero(dropout_rates)` instead of using a ternary is also a valid option.
Is `active` an argument to something? How would that work?
`active` is the third arg to `BatchNorm`, yes. I wouldn't worry about that now though; it's a minor tweak we can always revisit later.
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Mostly docstring changes.
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@lorenzoh this is good to go. I'll let you merge to have the final skim.
Adds a tabular model based on the Python fastai implementation.