-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Gunnar Raetsch benchmark datasets added
- Loading branch information
1 parent
4a6c2d5
commit 9a5d633
Showing
9 changed files
with
230 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Gunnar Raetsch benchmark datasets | ||
(https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets). | ||
@author: David Diaz Vico | ||
@license: MIT | ||
""" | ||
|
||
from functools import partial | ||
|
||
from .base import load_dataset | ||
|
||
|
||
load_banana = partial(load_dataset, name='banana') | ||
load_breast_cancer = partial(load_dataset, name='breast_cancer') | ||
load_diabetis = partial(load_dataset, name='diabetis') | ||
load_flare_solar = partial(load_dataset, name='flare_solar') | ||
load_german = partial(load_dataset, name='german') | ||
load_heart = partial(load_dataset, name='heart') | ||
load_image = partial(load_dataset, name='image') | ||
load_ringnorm = partial(load_dataset, name='ringnorm') | ||
load_splice = partial(load_dataset, name='splice') | ||
load_thyroid = partial(load_dataset, name='thyroid') | ||
load_titanic = partial(load_dataset, name='titanic') | ||
load_twonorm = partial(load_dataset, name='twonorm') | ||
load_waveform = partial(load_dataset, name='waveform') | ||
|
||
|
||
load = {'banana': load_banana, 'breast_cancer': load_breast_cancer, | ||
'diabetis': load_diabetis, 'flare_solar': load_flare_solar, | ||
'german': load_german, 'heart': load_heart, 'image': load_image, | ||
'ringnorm': load_ringnorm, 'splice': load_splice, | ||
'thyroid': load_thyroid, 'titanic': load_titanic, | ||
'twonorm': load_twonorm, 'waveform': load_waveform} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
""" | ||
Gunnar Raetsch benchmark datasets | ||
(https://github.com/tdiethe/gunnar_raetsch_benchmark_datasets). | ||
@author: David Diaz Vico | ||
@license: MIT | ||
""" | ||
|
||
from scipy.io import loadmat | ||
from sklearn.model_selection import BaseCrossValidator | ||
|
||
from ..base import Bunch | ||
|
||
|
||
class GunnarRaetschDatasetSplit(BaseCrossValidator): | ||
"""Predefined split cross-validator for Gunnar Raetsch datasets. | ||
Provides train/test indices to split data into train/test sets using a | ||
predefined scheme. | ||
Read more in the :ref:`User Guide <cross_validation>`. | ||
Parameters | ||
---------- | ||
train_splits: array-like, shape (n_samples,) | ||
List of indices for each training split. | ||
test_splits: array-like, shape (n_samples,) | ||
List of indices for each test split. | ||
""" | ||
|
||
def __init__(self, train_splits, test_splits): | ||
self.train_splits = train_splits - 1 | ||
self.test_splits = test_splits - 1 | ||
|
||
def split(self, X=None, y=None, groups=None): | ||
"""Generate indices to split data into training and test set. | ||
Parameters | ||
---------- | ||
X: object | ||
Always ignored, exists for compatibility. | ||
y: object | ||
Always ignored, exists for compatibility. | ||
groups: object | ||
Always ignored, exists for compatibility. | ||
Returns | ||
------- | ||
train: ndarray | ||
The training set indices for that split. | ||
test: ndarray | ||
The testing set indices for that split. | ||
""" | ||
for train_indices, test_indices in zip(self.train_splits, self.test_splits): | ||
yield (train_indices, test_indices) | ||
|
||
def get_n_splits(self, X=None, y=None, groups=None): | ||
"""Returns the number of splitting iterations in the cross-validator | ||
Parameters | ||
---------- | ||
X: object | ||
Always ignored, exists for compatibility. | ||
y: object | ||
Always ignored, exists for compatibility. | ||
groups: object | ||
Always ignored, exists for compatibility. | ||
Returns | ||
------- | ||
n_splits: int | ||
Returns the number of splitting iterations in the | ||
cross-validator. | ||
""" | ||
return len(self.train_splits) | ||
|
||
|
||
def load_dataset(name, return_X_y=False): | ||
"""Load dataset. | ||
Load a dataset. | ||
Parameters | ||
---------- | ||
name: string | ||
Dataset name. | ||
return_X_y: bool, default=False | ||
If True, returns (data, target) instead of a Bunch object.. | ||
Returns | ||
------- | ||
data: Bunch | ||
Dictionary-like object with all the data and metadata. | ||
X, y: arrays | ||
If return_X_y is True | ||
""" | ||
features, target, train_splits, test_splits = loadmat('skdatasets/gunnar_raetsch/benchmarks')[name][0][0] | ||
|
||
if return_X_y: | ||
return features, target | ||
|
||
return Bunch(features=features, target=target, | ||
splits=GunnarRaetschDatasetSplit(train_splits, test_splits)) |
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
Tests. | ||
@author: David Diaz Vico | ||
@license: MIT | ||
""" | ||
|
||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.model_selection import cross_val_score | ||
|
||
from ..base import check_load_dataset | ||
|
||
from skdatasets.gunnar_raetsch import (load_banana, load_breast_cancer, | ||
load_diabetis, load_flare_solar, | ||
load_german, load_heart, load_image, | ||
load_ringnorm, load_splice, load_thyroid, | ||
load_titanic, load_twonorm, | ||
load_waveform) | ||
|
||
|
||
def test_split(): | ||
"""Tests Gunnar Raetsch dataset splits.""" | ||
data = load_banana() | ||
X = data.features | ||
y = data.target | ||
splits = data.splits | ||
cross_val_score(LogisticRegression(), X, y=y, cv=splits) | ||
|
||
|
||
def test_load(): | ||
"""Tests Gunnar Raetsch benchmark datasets.""" | ||
datasets = {'banana': {'loader': load_banana, 'n_patterns': (5300, ), | ||
'n_variables': 2}, | ||
'breast_cancer': {'loader': load_breast_cancer, | ||
'n_patterns': (263, ), 'n_variables': 9}, | ||
'diabetis': {'loader': load_diabetis, 'n_patterns': (768, ), | ||
'n_variables': 8}, | ||
'flare_solar': {'loader': load_flare_solar, | ||
'n_patterns': (144, ), 'n_variables': 9}, | ||
'german': {'loader': load_german, 'n_patterns': (1000, ), | ||
'n_variables': 20}, | ||
'heart': {'loader': load_heart, 'n_patterns': (270, ), | ||
'n_variables': 13}, | ||
'image': {'loader': load_image, 'n_patterns': (2086, ), | ||
'n_variables': 18}, | ||
'ringnorm': {'loader': load_ringnorm, 'n_patterns': (7400, ), | ||
'n_variables': 20}, | ||
'splice': {'loader': load_splice, 'n_patterns': (2991, ), | ||
'n_variables': 60}, | ||
'thyroid': {'loader': load_thyroid, 'n_patterns': (215, ), | ||
'n_variables': 5}, | ||
'titanic': {'loader': load_titanic, 'n_patterns': (24, ), | ||
'n_variables': 3}, | ||
'twonorm': {'loader': load_twonorm, 'n_patterns': (7400, ), | ||
'n_variables': 20}, | ||
'waveform': {'loader': load_waveform, 'n_patterns': (5000, ), | ||
'n_variables': 21}} | ||
for dataset in datasets.values(): | ||
check_load_dataset(dataset['loader'], dataset['n_patterns'], | ||
dataset['n_variables'], (('features', 'target'), ), | ||
n_targets=1, n_folds=None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters