Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leave one out cross validation #101

Merged
merged 12 commits into from
Feb 8, 2017
Prev Previous commit
Next Next commit
Adding lovo, percent_explained and mse tests
  • Loading branch information
mortonjt committed Jan 26, 2017
commit bbb6a02fd939619f6d111c4ff74ee4acdb8d4d9a
41 changes: 37 additions & 4 deletions gneiss/regression/tests/test_ols.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,17 +303,50 @@ def test_loo(self):
table=self.y, metadata=self.x, tree=self.t2)
res.fit(regularized=True)
res_loo = res.loo(fit_regularized=True).astype(np.float)

pdt.assert_frame_equal(exp_loo, res_loo, check_less_precise=True)

def test_lovo(self):
self.assertTrue(False)
res = ols(formula="x1 + x2 + x3 + x4",
table=self.y, metadata=self.x, tree=self.t2)
res.fit()
exp_lovo = pd.DataFrame([[0.799364, 0.978214],
[0.799363, 0.097355],
[0.799368, 0.0973498],
[0.799364, 0.097354],
[0.799361, 0.0973575]],
columns=['mse', 'Rsquared'],
index=['Intercept', 'x1', 'x2', 'x3', 'x4'])
res_lovo = res.lovo().astype(np.float)
pdt.assert_frame_equal(exp_lovo, res_lovo)

# Make sure that the regularization works
exp_lovo = pd.DataFrame([[0.799364, 0.978214],
[0.799363, 0.097355],
[0.799368, 0.0973498],
[0.799364, 0.097354],
[0.799361, 0.0973575]],
columns=['mse', 'Rsquared'],
index=['Intercept', 'x1', 'x2', 'x3', 'x4'])
res = ols(formula="x1 + x2 + x3 + x4",
table=self.y, metadata=self.x, tree=self.t2)
res.fit(regularized=True)
res_lovo = res.lovo(fit_regularized=True).astype(np.float)
pdt.assert_frame_equal(exp_lovo, res_lovo, check_less_precise=True)

def test_percent_explained(self):
self.assertTrue(False)
res = ols(formula="x1 + x2 + x3 + x4",
table=self.y, metadata=self.x, tree=self.t2)
res.fit()
res_perc = res.percent_explained()
exp_perc = pd.Series({'y0': 0.009901,
'y1': 0.990099})
pdt.assert_series_equal(res_perc, exp_perc)

def test_mse(self):
self.assertTrue(False)
res = ols(formula="x1 + x2 + x3 + x4",
table=self.y, metadata=self.x, tree=self.t2)
res.fit()
self.assertEquals(res.mse, 0.87936961790361856)


if __name__ == "__main__":
Expand Down