diff --git a/gosdt/model/threshold_guess.py b/gosdt/model/threshold_guess.py index a7d721c..362015c 100644 --- a/gosdt/model/threshold_guess.py +++ b/gosdt/model/threshold_guess.py @@ -4,7 +4,7 @@ import time import random import sys -import os +import os from queue import Queue import pathlib @@ -16,7 +16,7 @@ # fit the tree using gradient boosted classifier def fit_boosted_tree(X, y, n_est=10, lr=0.1, d=1): - clf = GradientBoostingClassifier(loss='deviance', learning_rate=lr, n_estimators=n_est, max_depth=d, + clf = GradientBoostingClassifier(loss='log_loss', learning_rate=lr, n_estimators=n_est, max_depth=d, random_state=42) clf.fit(X, y) out = clf.score(X, y)