Add multi target classification #441

Merged
2 changes: 1 addition & 1 deletion README.md
@@ -100,7 +100,7 @@ from pytorch_tabular.config import (
data_config = DataConfig(
target=[
"target"
], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
], # target should always be a list.
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
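With the old "Multi-Task Classification is not implemented" caveat dropped, the target list may now hold several classification columns. A minimal sketch of such a config (column names here are illustrative, not from this PR):

```python
from pytorch_tabular.config import DataConfig

# Hypothetical column names, for illustration only.
data_config = DataConfig(
    target=["target_a", "target_b"],  # two classification targets, one model
    continuous_cols=["num_1", "num_2"],
    categorical_cols=["cat_1"],
)
```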
2 changes: 1 addition & 1 deletion docs/gs_usage.md
@@ -14,7 +14,7 @@ from pytorch_tabular.config import (
data_config = DataConfig(
target=[
"target"
], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented
], # target should always be a list.
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
@@ -532,7 +532,7 @@
"data_config = DataConfig(\n",
" target=[\n",
" target_col\n",
" ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
" ], # target should always be a list\n",
" continuous_cols=num_col_names,\n",
" categorical_cols=cat_col_names,\n",
")\n",
2,008 changes: 2,008 additions & 0 deletions docs/tutorials/15-Multi Target Classification.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions examples/__only_for_dev__/adhoc_scaffold.py
@@ -53,8 +53,7 @@ def print_metrics(y_true, y_pred, tag):
from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402

data_config = DataConfig(
# target should always be a list. Multi-targets are only supported for regression.
# Multi-Task Classification is not implemented
# target should always be a list.
target=["target"],
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
3 changes: 1 addition & 2 deletions examples/__only_for_dev__/to_test_dae.py
@@ -145,8 +145,7 @@ def print_metrics(y_true, y_pred, tag):
lr = 1e-3

data_config = DataConfig(
# target should always be a list. Multi-targets are only supported for regression.
# Multi-Task Classification is not implemented
# target should always be a list.
target=[target_name],
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
6 changes: 6 additions & 0 deletions src/pytorch_tabular/config/config.py
@@ -197,6 +197,8 @@ class InferredConfig:

output_dim (Optional[int]): The number of output targets

output_cardinality (Optional[List[int]]): The number of unique values in classification output

categorical_cardinality (Optional[List[int]]): The number of unique values in categorical features

embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a
@@ -216,6 +218,10 @@
default=None,
metadata={"help": "The number of output targets"},
)
output_cardinality: Optional[List[int]] = field(
default=None,
metadata={"help": "The number of unique values in classification output"},
)
categorical_cardinality: Optional[List[int]] = field(
default=None,
metadata={"help": "The number of unique values in categorical features"},
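For classification, `output_cardinality` is inferred as the per-target class counts, and `output_dim` becomes their sum (see the datamodule changes below). A toy sketch with assumed values:

```python
# Two classification targets with 3 and 2 classes (assumed values).
output_cardinality = [3, 2]
output_dim = sum(output_cardinality)  # 5: one concatenated logit block per target
```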
91 changes: 72 additions & 19 deletions src/pytorch_tabular/models/base_model.py
@@ -122,23 +122,43 @@ def __init__(
config.metrics_params.append(vars(metric))
if config.task == "classification":
config.metrics_prob_input = self.custom_metrics_prob_inputs
for i, mp in enumerate(config.metrics_params):
mp.sub_params_list = []
for j, num_classes in enumerate(inferred_config.output_cardinality):
config.metrics_params[i].sub_params_list.append(
OmegaConf.create(
{
"task": mp.get("task", "multiclass"),
"num_classes": mp.get("num_classes", num_classes),
}
)
)

# Updating default metrics in config
elif config.task == "classification":
# Adding metric_params to config for classification task
for i, mp in enumerate(config.metrics_params):
# For classification task, output_dim == number of classes
config.metrics_params[i]["task"] = mp.get("task", "multiclass")
config.metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim)
if config.metrics[i] in (
"accuracy",
"precision",
"recall",
"precision_recall",
"specificity",
"f1_score",
"fbeta_score",
):
config.metrics_params[i]["top_k"] = mp.get("top_k", 1)
mp.sub_params_list = []
for j, num_classes in enumerate(inferred_config.output_cardinality):

config.metrics_params[i].sub_params_list.append(
OmegaConf.create(
{"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)}
)
)

if config.metrics[i] in (
"accuracy",
"precision",
"recall",
"precision_recall",
"specificity",
"f1_score",
"fbeta_score",
):
config.metrics_params[i].sub_params_list[j]["top_k"] = mp.get("top_k", 1)

if self.custom_optimizer is not None:
config.optimizer = str(self.custom_optimizer.__class__.__name__)
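The loop above gives every metric a `sub_params_list` with one entry per target: that target's `task`, `num_classes`, and, for top-k style metrics, `top_k`. A sketch of the resulting structure for an accuracy metric over two targets with 3 and 2 classes (values assumed):

```python
from omegaconf import OmegaConf

sub_params_list = [
    OmegaConf.create({"task": "multiclass", "num_classes": 3, "top_k": 1}),
    OmegaConf.create({"task": "multiclass", "num_classes": 2, "top_k": 1}),
]
```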
@@ -267,7 +287,22 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tensor:
)
else:
# TODO loss fails with batch size of 1?
computed_loss = self.loss(y_hat.squeeze(), y.squeeze()) + reg_loss
computed_loss = reg_loss
start_index = 0
for i in range(len(self.hparams.output_cardinality)):
end_index = start_index + self.hparams.output_cardinality[i]
_loss = self.loss(y_hat[:, start_index:end_index], y[:, i])
computed_loss += _loss
if self.hparams.output_dim > 1:
self.log(
f"{tag}_loss_{i}",
_loss,
on_epoch=True,
on_step=False,
logger=True,
prog_bar=False,
)
start_index = end_index
self.log(
f"{tag}_loss",
computed_loss,
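The rewritten loss path slices the concatenated logits per target: columns `start_index:end_index` belong to target `i` and are scored against label column `y[:, i]`, with the per-target losses summed (and logged individually when there is more than one target). A standalone sketch of that slicing, with assumed shapes:

```python
import torch
import torch.nn as nn

output_cardinality = [3, 2]                      # assumed per-target class counts
y_hat = torch.randn(8, sum(output_cardinality))  # (batch, 5) concatenated logits
y = torch.stack(                                 # (batch, 2) integer labels
    [torch.randint(0, c, (8,)) for c in output_cardinality], dim=1
)
loss_fn = nn.CrossEntropyLoss()

computed_loss, start_index = 0.0, 0
for i, cardinality in enumerate(output_cardinality):
    end_index = start_index + cardinality
    computed_loss = computed_loss + loss_fn(y_hat[:, start_index:end_index], y[:, i])
    start_index = end_index
```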
@@ -325,11 +360,29 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> List[torch.Tensor]:
_metrics.append(_metric)
avg_metric = torch.stack(_metrics, dim=0).sum()
else:
y_hat = nn.Softmax(dim=-1)(y_hat.squeeze())
if prob_inp:
avg_metric = metric(y_hat, y.squeeze(), **metric_params)
else:
avg_metric = metric(torch.argmax(y_hat, dim=-1), y.squeeze(), **metric_params)
_metrics = []
start_index = 0
for i, cardinality in enumerate(self.hparams.output_cardinality):
end_index = start_index + cardinality
y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze())
if prob_inp:
_metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i])
else:
_metric = metric(
torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]
)
if len(self.hparams.output_cardinality) > 1:
self.log(
f"{tag}_{metric_str}_{i}",
_metric,
on_epoch=True,
on_step=False,
logger=True,
prog_bar=False,
)
_metrics.append(_metric)
start_index = end_index
avg_metric = torch.stack(_metrics, dim=0).sum()
metrics.append(avg_metric)
self.log(
f"{tag}_{metric_str}",
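`calculate_metrics` now applies the same slicing: softmax over each target's logit block, then the metric with that target's `sub_params_list` entry, logging per-target values when several targets exist. A standalone sketch using torchmetrics' functional accuracy (shapes and values assumed):

```python
import torch
from torchmetrics.functional import accuracy

output_cardinality = [3, 2]
y_hat = torch.randn(8, sum(output_cardinality))
y = torch.stack([torch.randint(0, c, (8,)) for c in output_cardinality], dim=1)

_metrics, start_index = [], 0
for i, cardinality in enumerate(output_cardinality):
    end_index = start_index + cardinality
    probs = torch.softmax(y_hat[:, start_index:end_index], dim=-1)
    _metrics.append(
        accuracy(probs.argmax(dim=-1), y[:, i], task="multiclass", num_classes=cardinality)
    )
    start_index = end_index
avg_metric = torch.stack(_metrics, dim=0).sum()  # summed across targets, as above
```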
27 changes: 20 additions & 7 deletions src/pytorch_tabular/tabular_datamodule.py
@@ -282,13 +282,21 @@ def _update_config(self, config) -> InferredConfig:
if config.task == "regression":
# self._output_dim_reg = len(config.target) if config.target else None if self.train is not None:
output_dim = len(config.target) if config.target else None
output_cardinality = None
elif config.task == "classification":
# self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None
if self.train is not None:
output_dim = len(np.unique(self.train[config.target[0]])) if config.target else None
output_cardinality = (
self.train[config.target].fillna("NA").nunique().tolist() if config.target else None
)
output_dim = sum(output_cardinality)
else:
output_dim = len(np.unique(self.train_dataset.y)) if config.target else None
output_cardinality = (
self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None
)
output_dim = sum(output_cardinality)
elif config.task == "ssl":
output_cardinality = None
output_dim = None
else:
raise ValueError(f"{config.task} is an unsupported task.")
@@ -308,6 +316,7 @@
categorical_dim=categorical_dim,
continuous_dim=continuous_dim,
output_dim=output_dim,
output_cardinality=output_cardinality,
categorical_cardinality=categorical_cardinality,
embedding_dims=embedding_dims,
)
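The inference of `output_cardinality` is a per-column `nunique` over the target frame, with missing values counted as their own class via `fillna("NA")`. On a toy frame (data illustrative):

```python
import pandas as pd

train = pd.DataFrame({"t1": ["a", "b", None, "a"], "t2": [0, 1, 1, 0]})
target = ["t1", "t2"]

output_cardinality = train[target].fillna("NA").nunique().tolist()  # [3, 2]
output_dim = sum(output_cardinality)                                # 5
```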
@@ -381,11 +390,14 @@ def _label_encode_target(self, data: DataFrame, stage: str) -> DataFrame:
if self.config.task != "classification":
return data
if stage == "fit" or self.label_encoder is None:
self.label_encoder = LabelEncoder()
data[self.config.target[0]] = self.label_encoder.fit_transform(data[self.config.target[0]])
self.label_encoder = [None] * len(self.config.target)
for i in range(len(self.config.target)):
self.label_encoder[i] = LabelEncoder()
data[self.config.target[i]] = self.label_encoder[i].fit_transform(data[self.config.target[i]])
else:
if self.config.target[0] in data.columns:
data[self.config.target[0]] = self.label_encoder.transform(data[self.config.target[0]])
for i in range(len(self.config.target)):
if self.config.target[i] in data.columns:
data[self.config.target[i]] = self.label_encoder[i].transform(data[self.config.target[i]])
return data

def _target_transform(self, data: DataFrame, stage: str) -> DataFrame:
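`label_encoder` becomes a list holding one `LabelEncoder` per target, fit during the `fit` stage and reused for later transforms. A minimal sketch of the same pattern on toy data:

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

df = pd.DataFrame({"t1": ["cat", "dog", "cat"], "t2": ["x", "y", "y"]})
targets = ["t1", "t2"]

label_encoder = [LabelEncoder() for _ in targets]
for i, col in enumerate(targets):
    df[col] = label_encoder[i].fit_transform(df[col])
```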
@@ -818,7 +830,8 @@ def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
# TODO Is the target encoding necessary?
if len(set(self.target) - set(df.columns)) > 0:
if self.config.task == "classification":
df.loc[:, self.target] = np.array([self.label_encoder.classes_[0]] * len(df)).reshape(-1, 1)
for i in range(len(self.target)):
df.loc[:, self.target[i]] = np.array([self.label_encoder[i].classes_[0]] * len(df)).reshape(-1, 1)
else:
df.loc[:, self.target] = np.zeros((len(df), len(self.target)))
df, _ = self.preprocess_data(df, stage="inference")
31 changes: 15 additions & 16 deletions src/pytorch_tabular/tabular_model.py
@@ -211,9 +211,6 @@ def num_params(self):

def _run_validation(self):
"""Validates the Config params and throws errors if something is wrong."""
if self.config.task == "classification":
if len(self.config.target) > 1:
raise NotImplementedError("Multi-Target Classification is not implemented.")
if self.config.task == "regression":
if self.config.target_range is not None:
if (
@@ -1291,12 +1288,16 @@ def _format_predicitons(
pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1)

elif self.config.task == "classification":
point_predictions = nn.Softmax(dim=-1)(point_predictions).numpy()
for i, class_ in enumerate(self.datamodule.label_encoder.classes_):
pred_df[f"{class_}_probability"] = point_predictions[:, i]
pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform(
np.argmax(point_predictions, axis=1)
)
start_index = 0
for i, target_col in enumerate(self.config.target):
end_index = start_index + self.datamodule._inferred_config.output_cardinality[i]
prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy()
start_index = end_index
for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_):
pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j]
pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform(
np.argmax(prob_prediction, axis=1)
)
warnings.warn(
"Classification prediction column will be renamed to"
" `{target_col}_prediction` in the next release to maintain"
@@ -2046,23 +2047,21 @@ def _combine_predictions(
elif callable(aggregate):
bagged_pred = aggregate(pred_prob_l)
if self.config.task == "classification":
classes = self.datamodule.label_encoder.classes_
# FIXME need to iterate .label_encoder[x]
Owner comment: Maybe we should give an error message if somebody attempts bagging predict with multi-label classification?
classes = self.datamodule.label_encoder[0].classes_
if aggregate == "hard_voting":
pred_df = pd.DataFrame(
np.concatenate(pred_prob_l, axis=1),
columns=[
f"{c}_probability_fold_{i}"
for i in range(len(pred_prob_l))
for c in self.datamodule.label_encoder.classes_
],
columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes],
index=pred_idx,
)
pred_df["prediction"] = classes[final_pred]
else:
final_pred = classes[np.argmax(bagged_pred, axis=1)]
pred_df = pd.DataFrame(
bagged_pred,
columns=[f"{c}_probability" for c in self.datamodule.label_encoder.classes_],
# FIXME
columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_],
index=pred_idx,
)
pred_df["prediction"] = final_pred
4 changes: 3 additions & 1 deletion tests/test_autoint.py
@@ -78,6 +78,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]


@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,6 +92,7 @@
@pytest.mark.parametrize("batch_norm_continuous_input", [True, False])
def test_classification(
classification_data,
multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -100,7 +102,7 @@
):
(train, test, target) = classification_data
data_config = DataConfig(
target=target,
target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
4 changes: 3 additions & 1 deletion tests/test_categorical_embedding.py
@@ -124,6 +124,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]


@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -136,6 +137,7 @@
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
@@ -146,7 +148,7 @@
return

data_config = DataConfig(
target=target,
target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,
2 changes: 1 addition & 1 deletion tests/test_common.py
@@ -753,7 +753,7 @@ def test_cross_validate_regression(
[
"accuracy",
None,
lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"].values),
lambda y_true, y_pred: accuracy_score(y_true, y_pred["target_prediction"].values),
],
)
@pytest.mark.parametrize("return_oof", [True])
4 changes: 3 additions & 1 deletion tests/test_danet.py
@@ -80,6 +80,7 @@ def test_regression(
assert pred_df.shape[0] == test.shape[0]


@pytest.mark.parametrize("multi_target", [False, True])
@pytest.mark.parametrize(
"continuous_cols",
[
@@ -91,14 +92,15 @@ def test_regression(
@pytest.mark.parametrize("normalize_continuous_features", [True])
def test_classification(
classification_data,
multi_target,
continuous_cols,
categorical_cols,
continuous_feature_transform,
normalize_continuous_features,
):
(train, test, target) = classification_data
data_config = DataConfig(
target=target,
target=target + ["feature_53"] if multi_target else target,
continuous_cols=continuous_cols,
categorical_cols=categorical_cols,
continuous_feature_transform=continuous_feature_transform,