Skip to content

Commit

Permalink
MAINT: add threads/jobs primitives (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
colinvwood authored Feb 13, 2024
1 parent 7aef7e7 commit 8d7ea52
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
7 changes: 7 additions & 0 deletions q2_sample_classifier/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sklearn.pipeline import Pipeline

import qiime2
from qiime2.plugin import get_available_cores
import pandas as pd
import biom
import skbio
Expand Down Expand Up @@ -107,6 +108,9 @@ def _fit_predict_knn_cv(
x: pd.DataFrame, y: pd.Series, k: int, cv: int,
random_state: int, n_jobs: int
) -> (pd.Series, pd.Series):
if n_jobs == 0:
n_jobs = get_available_cores()

kf = KFold(n_splits=cv, shuffle=True, random_state=random_state)

# train and test with CV
Expand Down Expand Up @@ -291,6 +295,9 @@ def fit_regressor(table: biom.Table,


def predict_base(table, sample_estimator, n_jobs):
if n_jobs == 0:
n_jobs = get_available_cores()

# extract feature data from biom
feature_data = _extract_features(table)
index = table.ids()
Expand Down
4 changes: 2 additions & 2 deletions q2_sample_classifier/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from qiime2.plugin import (
Int, Str, Float, Range, Bool, Plugin, Metadata, Choices, MetadataColumn,
Numeric, Categorical, Citations, Visualization, TypeMatch)
Numeric, Categorical, Citations, Visualization, TypeMatch, Threads)
from q2_types.feature_table import (
FeatureTable, Frequency, RelativeFrequency, PresenceAbsence, Balance,
PercentileNormalized, Design, Composition)
Expand Down Expand Up @@ -100,7 +100,7 @@
parameters = {
'base': {
'random_state': Int,
'n_jobs': Int,
'n_jobs': Threads,
'n_estimators': Int % Range(1, None),
'missing_samples': Str % Choices(['error', 'ignore'])},
'splitter': {
Expand Down
7 changes: 7 additions & 0 deletions q2_sample_classifier/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from sklearn.pipeline import Pipeline

from qiime2.plugin import get_available_cores
import q2templates
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -264,6 +265,9 @@ def nested_cross_validation(table, metadata, cv, random_state, n_jobs,
n_estimators, estimator, stratify,
parameter_tuning, classification, scoring,
missing_samples='error'):
if n_jobs == 0:
n_jobs = get_available_cores()

# extract column name from NumericMetadataColumn
column = metadata.name

Expand Down Expand Up @@ -301,6 +305,9 @@ def _fit_estimator(features, targets, estimator, n_estimators=100, step=0.05,
cv=5, random_state=None, n_jobs=1,
optimize_feature_selection=False, parameter_tuning=False,
missing_samples='error', classification=True):
if n_jobs == 0:
n_jobs = get_available_cores()

# extract column name from CategoricalMetadataColumn
column = targets.to_series().name

Expand Down

0 comments on commit 8d7ea52

Please sign in to comment.