Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rest and intercept column to glexobj$m #25

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open

Add rest and intercept column to glexobj$m #25

wants to merge 19 commits into from

Conversation

jyliuu
Copy link
Collaborator

@jyliuu jyliuu commented Nov 18, 2024

To be merged after #24.
This PR adds an intercept column to glex_obj$m for all supported tree-methods. Moreover, if the task is regression, a rest column will also be added which is equal to predictions - rowSums(glex_obj$m)

@jyliuu jyliuu marked this pull request as ready for review November 18, 2024 15:09
@jemus42
Copy link
Collaborator

jemus42 commented Nov 19, 2024

I think we need to be careful here, as changing the structure of the glex object has ripple effects across every plotting and reshaping function and needs to be consistent across the 2 (or 3, counting rpf) supported learners.

Some thoughts:

  1. $intercept was stored as a scalar specifically to avoid having to store a constant column in $m. I understand that it's more convenient to have it all stored in $m such that things like rowSums "just work", but it also feels kind of inefficient (probably wouldn't matter too much, I admit). We can change that structure if it ends up being easier down the road, but there's a few considerations to make, even minor things like "what happens if the dataset happens to contain a feature named intercept for some reason?"

  2. For the randomPlantedForest output, there already is a $remainder term, see example below. I don't care whether we call it $rest or $remainder, but naturally we should be consistent 😅 I also see that it doesn't help that rpf basically has its own machinery in a different package, which just adds to the complexity. That's one of the things that makes the remainder thing complicated, with the other part being that the remainder needs to make sense for regression, binary classification, and multiclass classification, which requires extra handling. See e.g. here in rpf https://github.com/PlantedML/randomPlantedForest/blob/41fe7eef99cfc60fbc04f6a08d30a932ffc097e0/R/predict_components.R#L124-L143 and see Add remainder vector to glex output #11 as well. Happy to make progress here though, it's been on my list for some time, but I wanted to avoid a "this just works for regression now sry" type situation, which I always find frustrating as a user.

  3. On that note, we might also address the part where $shap is stored even if max_interaction is set in glex which then makes $shap effectively meaningless (see also Calculate shap values only for the selected features #18)

Example for remainder term

library(glex)
library(xgboost)
library(randomPlantedForest)
set.seed(234)
options(max.print = 10)

# this is completely arbitrary nonsense
xdat <- data.frame(
  x1 = rnorm(100),
  x2 = rpois(100, 2),
  x3 = runif(100)
)
xdat <- within(xdat, y <- 3 * x1 + 0.5 * (x2 + x3) + 3 * abs(x1 * x3))

# rpf has remainder term
rpf_fit <- rpf(y ~ ., data = xdat, num.trees = 50, max_interaction = 3)
rpf_glex <- glex(rpf_fit, xdat, max_interaction = 2)
rpf_glex$remainder
#>  [1] -0.0153642314 -0.0166246943 -0.0396781809  0.0453551935 -0.0003334053
#>  [6]  0.0038383014  0.0034622827  0.0306458754 -0.0007525907 -0.0162234933
#>  [ reached getOption("max.print") -- omitted 90 entries ]

# also, intercept is stored as scalar
rpf_glex$intercept
#> [1] 2.053044


# xgb not yet
xgb_fit <- xgboost(data = as.matrix(xdat[, 1:3]), label = xdat$y, max_depth = 3, 
                   early_stopping_rounds = 50, nrounds = 1000, verbose = FALSE)
xgb_glex <- glex(xgb_fit, as.matrix(xdat[, 1:3]), max_interaction = 2)
xgb_glex$remainder
#> NULL

xgb_glex$intercept
#> [1] 2.141527

# Also, shap is stored but known to be wrong due to max_interaction limit
xgb_glex$shap
#>               x1           x2          x3
#>            <num>        <num>       <num>
#>   1:  1.98061587  0.056079863  0.12250051
#>   2: -3.80587144 -0.281592150 -0.33564143
#>  [ reached getOption("max.print") -- omitted 99 rows ]

Created on 2024-11-19 with reprex v2.1.1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants