From f33e3b579d685a570cd7edf1b9a10944c379c491 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 28 May 2020 18:19:13 +0800 Subject: [PATCH] Fix loading old model. --- src/learner.cc | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index de9620c9a8b2..1b92d1db77b5 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -694,15 +694,24 @@ class LearnerIO : public LearnerConfiguration { warn_old_model = false; } - if (mparam_.major_version >= 1) { - learner_model_param_ = LearnerModelParam(mparam_, - obj_->ProbToMargin(mparam_.base_score)); - } else { + if (mparam_.major_version < 1) { // Before 1.0.0, base_score is saved as a transformed value, and there's no version // attribute in the saved model. - learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score); + std::string multi{"multi:"}; + if (!std::equal(tparam_.objective.cbegin(), tparam_.objective.cend(), + multi.begin())) { + HostDeviceVector t; + t.HostVector().resize(1); + t.HostVector().at(0) = mparam_.base_score; + this->obj_->PredTransform(&t); + auto base_score = t.HostVector().at(0); + mparam_.base_score = base_score; + } warn_old_model = true; } + + learner_model_param_ = + LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score)); if (attributes_.find("objective") != attributes_.cend()) { auto obj_str = attributes_.at("objective"); auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()});