Skip to content

Commit

Permalink
[SPARK-21275][ML] Update GLM test to use supportedFamilyNames
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
Update GLM test to use supportedFamilyNames as suggested here:
#16699 (diff)

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes #18495 from actuaryzhang/mlGlmTest2.
  • Loading branch information
actuaryzhang authored and yanboliang committed Jul 1, 2017
1 parent b1d719e commit 37ef32e
Showing 1 changed file with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -749,15 +749,15 @@ class GeneralizedLinearRegressionSuite
library(statmod)
y <- c(1.0, 0.5, 0.7, 0.3)
w <- c(1, 2, 3, 4)
for (fam in list(gaussian(), poisson(), binomial(), Gamma(), tweedie(1.6))) {
for (fam in list(binomial(), Gamma(), gaussian(), poisson(), tweedie(1.6))) {
model1 <- glm(y ~ 1, family = fam)
model2 <- glm(y ~ 1, family = fam, weights = w)
print(as.vector(c(coef(model1), coef(model2))))
}
[1] 0.625 0.530
[1] -0.4700036 -0.6348783
[1] 0.5108256 0.1201443
[1] 1.600000 1.886792
[1] 0.625 0.530
[1] -0.4700036 -0.6348783
[1] 1.325782 1.463641
*/

Expand All @@ -768,13 +768,13 @@ class GeneralizedLinearRegressionSuite
Instance(0.3, 4.0, Vectors.zeros(0))
).toDF()

val expected = Seq(0.625, 0.530, -0.4700036, -0.6348783, 0.5108256, 0.1201443,
1.600000, 1.886792, 1.325782, 1.463641)
val expected = Seq(0.5108256, 0.1201443, 1.600000, 1.886792, 0.625, 0.530,
-0.4700036, -0.6348783, 1.325782, 1.463641)

import GeneralizedLinearRegression._

var idx = 0
for (family <- Seq("gaussian", "poisson", "binomial", "gamma", "tweedie")) {
for (family <- GeneralizedLinearRegression.supportedFamilyNames.sortWith(_ < _)) {
for (useWeight <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily(family)
if (useWeight) trainer.setWeightCol("weight")
Expand Down Expand Up @@ -807,7 +807,7 @@ class GeneralizedLinearRegressionSuite
0.5, 2.1, 0.5, 1.0, 2.0,
0.9, 0.4, 1.0, 2.0, 1.0,
0.7, 0.7, 0.0, 3.0, 3.0), 4, 5, byrow = TRUE))
families <- list(gaussian, binomial, poisson, Gamma, tweedie(1.5))
families <- list(binomial, Gamma, gaussian, poisson, tweedie(1.5))
f1 <- V1 ~ -1 + V4 + V5
f2 <- V1 ~ V4 + V5
for (f in c(f1, f2)) {
Expand All @@ -816,15 +816,15 @@ class GeneralizedLinearRegressionSuite
print(as.vector(coef(model)))
}
}
[1] 0.5169222 -0.3344444
[1] 0.9419107 -0.6864404
[1] 0.1812436 -0.6568422
[1] -0.2869094 0.7857710
[1] 0.5169222 -0.3344444
[1] 0.1812436 -0.6568422
[1] 0.1055254 0.2979113
[1] -0.05990345 0.53188982 -0.32118415
[1] -0.2147117 0.9911750 -0.6356096
[1] -1.5616130 0.6646470 -0.3192581
[1] 0.3390397 -0.3406099 0.6870259
[1] -0.05990345 0.53188982 -0.32118415
[1] -1.5616130 0.6646470 -0.3192581
[1] 0.3665034 0.1039416 0.1484616
*/
val dataset = Seq(
Expand All @@ -835,23 +835,22 @@ class GeneralizedLinearRegressionSuite
).toDF()

val expected = Seq(
Vectors.dense(0, 0.5169222, -0.3344444),
Vectors.dense(0, 0.9419107, -0.6864404),
Vectors.dense(0, 0.1812436, -0.6568422),
Vectors.dense(0, -0.2869094, 0.785771),
Vectors.dense(0, 0.5169222, -0.3344444),
Vectors.dense(0, 0.1812436, -0.6568422),
Vectors.dense(0, 0.1055254, 0.2979113),
Vectors.dense(-0.05990345, 0.53188982, -0.32118415),
Vectors.dense(-0.2147117, 0.991175, -0.6356096),
Vectors.dense(-1.561613, 0.664647, -0.3192581),
Vectors.dense(0.3390397, -0.3406099, 0.6870259),
Vectors.dense(-0.05990345, 0.53188982, -0.32118415),
Vectors.dense(-1.561613, 0.664647, -0.3192581),
Vectors.dense(0.3665034, 0.1039416, 0.1484616))

import GeneralizedLinearRegression._

var idx = 0

for (fitIntercept <- Seq(false, true)) {
for (family <- Seq("gaussian", "binomial", "poisson", "gamma", "tweedie")) {
for (family <- GeneralizedLinearRegression.supportedFamilyNames.sortWith(_ < _)) {
val trainer = new GeneralizedLinearRegression().setFamily(family)
.setFitIntercept(fitIntercept).setOffsetCol("offset")
.setWeightCol("weight").setLinkPredictionCol("linkPrediction")
Expand Down

0 comments on commit 37ef32e

Please sign in to comment.