From 8467880aeb442ea32f18fcdde93369bbfddd3fba Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 1 Jun 2020 04:32:24 +0800 Subject: [PATCH] Fix loading old model. (#5724) (#5737) * Add test. --- src/learner.cc | 20 ++++++++++++++------ tests/pytest.ini | 3 ++- tests/python/test_model_compatibility.py | 4 ++++ 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/learner.cc b/src/learner.cc index e4c925ebf5b6..096ccd0c396b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -689,15 +689,23 @@ class LearnerIO : public LearnerConfiguration { warn_old_model = false; } - if (mparam_.major_version >= 1) { - learner_model_param_ = LearnerModelParam(mparam_, - obj_->ProbToMargin(mparam_.base_score)); - } else { + if (mparam_.major_version < 1) { // Before 1.0.0, base_score is saved as a transformed value, and there's no version - // attribute in the saved model. - learner_model_param_ = LearnerModelParam(mparam_, mparam_.base_score); + // attribute (saved a 0) in the saved model. + std::string multi{"multi:"}; + if (!std::equal(multi.cbegin(), multi.cend(), tparam_.objective.cbegin())) { + HostDeviceVector t; + t.HostVector().resize(1); + t.HostVector().at(0) = mparam_.base_score; + this->obj_->PredTransform(&t); + auto base_score = t.HostVector().at(0); + mparam_.base_score = base_score; + } warn_old_model = true; } + + learner_model_param_ = + LearnerModelParam(mparam_, obj_->ProbToMargin(mparam_.base_score)); if (attributes_.find("objective") != attributes_.cend()) { auto obj_str = attributes_.at("objective"); auto j_obj = Json::Load({obj_str.c_str(), obj_str.size()}); diff --git a/tests/pytest.ini b/tests/pytest.ini index 80c6579a80ad..aa0c89344ca5 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,3 +1,4 @@ [pytest] markers = - mgpu: Mark a test that requires multiple GPUs to run. \ No newline at end of file + mgpu: Mark a test that requires multiple GPUs to run. + ci: Mark a test that runs only on CI. \ No newline at end of file diff --git a/tests/python/test_model_compatibility.py b/tests/python/test_model_compatibility.py index 3ab85c74be8a..a37d5ecb24b0 100644 --- a/tests/python/test_model_compatibility.py +++ b/tests/python/test_model_compatibility.py @@ -4,6 +4,7 @@ import json import zipfile import pytest +import copy def run_model_param_check(config): @@ -124,6 +125,9 @@ def test_model_compatibility(): if name.startswith('xgboost-'): booster = xgboost.Booster(model_file=path) run_booster_check(booster, name) + # Do full serialization. + booster = copy.copy(booster) + run_booster_check(booster, name) elif name.startswith('xgboost_scikit'): run_scikit_model_check(name, path) else: