From 31a37c35f29b8f945eae8a33140af98c58fd7bb9 Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 10 Apr 2024 20:16:19 +0000 Subject: [PATCH 01/11] First pass at Multi-Target classifier. Core functionality works, but failing other tests --- src/pytorch_tabular/config/config.py | 6 ++++ src/pytorch_tabular/models/base_model.py | 42 +++++++++++++++++++---- src/pytorch_tabular/tabular_datamodule.py | 23 +++++++++---- src/pytorch_tabular/tabular_model.py | 19 +++++----- tests/test_gandalf.py | 4 ++- 5 files changed, 71 insertions(+), 23 deletions(-) diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index ef4c9ef5..f6106cdd 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -195,6 +195,8 @@ class InferredConfig: output_dim (Optional[int]): The number of output targets + output_cardinality (Optional[int]): The number of unique values in classification output + categorical_cardinality (Optional[List[int]]): The number of unique values in categorical features embedding_dims (Optional[List]): The dimensions of the embedding for each categorical column as a @@ -213,6 +215,10 @@ class InferredConfig: default=None, metadata={"help": "The number of output targets"}, ) + output_cardinality: Optional[List[int]] = field( + default=None, + metadata={"help": "The number of unique values in classification output"}, + ) categorical_cardinality: Optional[List[int]] = field( default=None, metadata={"help": "The number of unique values in categorical features"}, diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index b15eed4d..51f1b7e8 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -263,7 +263,22 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso ) else: # TODO loss fails with batch size of 1? 
- computed_loss = self.loss(y_hat.squeeze(), y.squeeze()) + reg_loss + computed_loss = reg_loss + start_index = 0 + for i in range(len(self.hparams.output_cardinality)): + end_index = start_index + self.hparams.output_cardinality[i] + _loss = self.loss(y_hat[:, start_index:end_index], y[:, i]) + computed_loss += _loss + if self.hparams.output_dim > 1: + self.log( + f"{tag}_loss_{i}", + _loss, + on_epoch=True, + on_step=False, + logger=True, + prog_bar=False, + ) + start_index = end_index self.log( f"{tag}_loss", computed_loss, @@ -320,11 +335,26 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L _metrics.append(_metric) avg_metric = torch.stack(_metrics, dim=0).sum() else: - y_hat = nn.Softmax(dim=-1)(y_hat.squeeze()) - if prob_inp: - avg_metric = metric(y_hat, y.squeeze(), **metric_params) - else: - avg_metric = metric(torch.argmax(y_hat, dim=-1), y.squeeze(), **metric_params) + _metrics = [] + start_index = 0 + for i in range(len(self.hparams.output_cardinality)): + end_index = start_index + self.hparams.output_cardinality[i] + y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze()) + if prob_inp: + _metric = metric(y_hat_i, y[:,i:i+1].squeeze(), **metric_params) + else: + _metric = metric(torch.argmax(y_hat_i, dim=-1), y[:,i:i+1].squeeze(), **metric_params) + if len(self.hparams.output_cardinality) > 1: + self.log( + f"{tag}_{metric_str}_{i}", + _metric, + on_epoch=True, + on_step=False, + logger=True, + prog_bar=False, + ) + _metrics.append(_metric) + avg_metric = torch.stack(_metrics, dim=0).sum() metrics.append(avg_metric) self.log( f"{tag}_{metric_str}", diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 5241a395..b19cfd64 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -278,13 +278,17 @@ def _update_config(self, config) -> InferredConfig: if config.task == "regression": # self._output_dim_reg = len(config.target) if config.target else None if self.train is not None: output_dim = len(config.target) if config.target else None + output_cardinality = None elif config.task == "classification": # self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None if self.train is not None: - output_dim = len(np.unique(self.train[config.target[0]])) if config.target else None + output_cardinality = self.train[config.target].fillna("NA").nunique().tolist() if config.target else None + output_dim = sum(output_cardinality) else: - output_dim = len(np.unique(self.train_dataset.y)) if config.target else None + output_cardinality = self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None + output_dim = sum(output_cardinality) elif config.task == "ssl": + output_cardinality = None output_dim = None else: raise ValueError(f"{config.task} is an unsupported task.") @@ -304,6 +308,7 @@ def _update_config(self, config) -> InferredConfig: categorical_dim=categorical_dim, continuous_dim=continuous_dim, output_dim=output_dim, + output_cardinality=output_cardinality, categorical_cardinality=categorical_cardinality, embedding_dims=embedding_dims, ) @@ -376,11 +381,14 @@ def _label_encode_target(self, data: DataFrame, stage: str) -> DataFrame: if self.config.task != "classification": return data if stage == "fit" or self.label_encoder is None: - self.label_encoder = LabelEncoder() - data[self.config.target[0]] = self.label_encoder.fit_transform(data[self.config.target[0]]) + 
self.label_encoder = [None] * len(self.config.target) + for i in range(len(self.config.target)): + self.label_encoder[i] = LabelEncoder() + data[self.config.target[i]] = self.label_encoder[i].fit_transform(data[self.config.target[i]]) else: - if self.config.target[0] in data.columns: - data[self.config.target[0]] = self.label_encoder.transform(data[self.config.target[0]]) + for i in range(len(self.config.target)): + if self.config.target[i] in data.columns: + data[self.config.target[i]] = self.label_encoder[i].transform(data[self.config.target[i]]) return data def _target_transform(self, data: DataFrame, stage: str) -> DataFrame: @@ -803,7 +811,8 @@ def _prepare_inference_data(self, df: DataFrame) -> DataFrame: # TODO Is the target encoding necessary? if len(set(self.target) - set(df.columns)) > 0: if self.config.task == "classification": - df.loc[:, self.target] = np.array([self.label_encoder.classes_[0]] * len(df)).reshape(-1, 1) + for i in range(len(self.target)): + df.loc[:, self.target[i]] = np.array([self.label_encoder[i].classes_[0]] * len(df)).reshape(-1, 1) else: df.loc[:, self.target] = np.zeros((len(df), len(self.target))) df, _ = self.preprocess_data(df, stage="inference") diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 04a496bd..5adf1a7d 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -215,9 +215,6 @@ def num_params(self): def _run_validation(self): """Validates the Config params and throws errors if something is wrong.""" - if self.config.task == "classification": - if len(self.config.target) > 1: - raise NotImplementedError("Multi-Target Classification is not implemented.") if self.config.task == "regression": if self.config.target_range is not None: if ( @@ -1288,12 +1285,16 @@ def _format_predicitons( pred_df[f"{target_col}_q{int(q*100)}"] = quantile_predictions[:, j, i].reshape(-1, 1) elif self.config.task == "classification": - point_predictions = nn.Softmax(dim=-1)(point_predictions).numpy() - for i, class_ in enumerate(self.datamodule.label_encoder.classes_): - pred_df[f"{class_}_probability"] = point_predictions[:, i] - pred_df["prediction"] = self.datamodule.label_encoder.inverse_transform( - np.argmax(point_predictions, axis=1) - ) + start_index = 0 + for i, target_col in enumerate(self.config.target): + end_index = start_index + self.datamodule._inferred_config.output_cardinality[i] + prob_prediction = nn.Softmax(dim=-1)(point_predictions[:, start_index:end_index]).numpy() + start_index = end_index + for j, class_ in enumerate(self.datamodule.label_encoder[i].classes_): + pred_df[f"{target_col}_{class_}_probability"] = prob_prediction[:, j] + pred_df[f"{target_col}_prediction"] = self.datamodule.label_encoder[i].inverse_transform( + np.argmax(prob_prediction, axis=1) + ) warnings.warn( "Classification prediction column will be renamed to" " `{target_col}_prediction` in the next release to maintain" diff --git a/tests/test_gandalf.py b/tests/test_gandalf.py index 912cbd79..d3787ef3 100644 --- a/tests/test_gandalf.py +++ b/tests/test_gandalf.py @@ -78,6 +78,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -89,6 +90,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -96,7 +98,7 @@ def 
test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, From 3e5c080c5672264a90210202e5052299f80ec998 Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Mon, 22 Apr 2024 20:25:28 +0000 Subject: [PATCH 02/11] Updated base model to support custom metrics on multi-target --- src/pytorch_tabular/models/base_model.py | 46 ++++++++++++++---------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index 51f1b7e8..c92a4abd 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torchmetrics -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, ListConfig from pandas import DataFrame from torch import Tensor from torch.optim import Optimizer @@ -121,21 +121,31 @@ def __init__( config.metrics_prob_input = self.custom_metrics_prob_inputs # Updating default metrics in config elif config.task == "classification": + # FIXME need to revise metrics_params already from above if there are custom + # For classification, metrics_params becomes a 2D list + # config.metrics_params[0] = [] # Adding metric_params to config for classification task for i, mp in enumerate(config.metrics_params): - # For classification task, output_dim == number of classses - config.metrics_params[i]["task"] = mp.get("task", "multiclass") - config.metrics_params[i]["num_classes"] = mp.get("num_classes", inferred_config.output_dim) - if config.metrics[i] in ( - "accuracy", - "precision", - "recall", - "precision_recall", - "specificity", - "f1_score", - "fbeta_score", - ): - config.metrics_params[i]["top_k"] = mp.get("top_k", 1) + mp.sub_params_list = [] + for j, num_classes in enumerate(inferred_config.output_cardinality): + # For classification task, output_dim == number of classses + #config.metrics_params[i].append() + #config.metrics_params[i][j]["task"] = mp.get("task", "multiclass") + #config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes) + + config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), + "num_classes": mp.get("num_classes", num_classes)})) + + if config.metrics[i] in ( + "accuracy", + "precision", + "recall", + "precision_recall", + "specificity", + "f1_score", + "fbeta_score", + ): + config.metrics_params[i].sub_params_list[j]["top_k"] = mp.get("top_k", 1) if self.custom_optimizer is not None: config.optimizer = str(self.custom_optimizer.__class__.__name__) @@ -337,13 +347,13 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L else: _metrics = [] start_index = 0 - for i in range(len(self.hparams.output_cardinality)): - end_index = start_index + self.hparams.output_cardinality[i] + for i, cardinality in enumerate(self.hparams.output_cardinality): + end_index = start_index + cardinality y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze()) if prob_inp: - _metric = metric(y_hat_i, y[:,i:i+1].squeeze(), **metric_params) + _metric = metric(y_hat_i, y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) else: - _metric = metric(torch.argmax(y_hat_i, dim=-1), y[:,i:i+1].squeeze(), **metric_params) + _metric = 
metric(torch.argmax(y_hat_i, dim=-1), y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) if len(self.hparams.output_cardinality) > 1: self.log( f"{tag}_{metric_str}_{i}", From 1df34a85c3dc3c8b5ec1ec22746f9fd03542023e Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 24 Apr 2024 19:50:25 +0000 Subject: [PATCH 03/11] fix to init metrics param config in multi-target --- src/pytorch_tabular/models/base_model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index c92a4abd..a9fae691 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -119,17 +119,18 @@ def __init__( config.metrics_params.append(vars(metric)) if config.task == "classification": config.metrics_prob_input = self.custom_metrics_prob_inputs + for i, mp in enumerate(config.metrics_params): + mp.sub_params_list = [] + for j, num_classes in enumerate(inferred_config.output_cardinality): + config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), + "num_classes": mp.get("num_classes", num_classes)})) + # Updating default metrics in config elif config.task == "classification": - # FIXME need to revise metrics_params already from above if there are custom - # For classification, metrics_params becomes a 2D list - # config.metrics_params[0] = [] # Adding metric_params to config for classification task for i, mp in enumerate(config.metrics_params): mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - # For classification task, output_dim == number of classses - #config.metrics_params[i].append() #config.metrics_params[i][j]["task"] = mp.get("task", "multiclass") #config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes) From 76963a89e442d9e80842b7f82c41c4361b3f3e04 Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 24 Apr 2024 19:50:45 +0000 Subject: [PATCH 04/11] updates pytests to include multi-target classification --- tests/test_autoint.py | 4 +++- tests/test_categorical_embedding.py | 4 +++- tests/test_common.py | 3 ++- tests/test_danet.py | 4 +++- tests/test_ft_transformer.py | 4 +++- tests/test_gate.py | 4 +++- tests/test_mdn.py | 4 +++- tests/test_node.py | 4 +++- tests/test_ssl.py | 4 +++- tests/test_tabnet.py | 4 +++- tests/test_tabtransformer.py | 4 +++- 11 files changed, 32 insertions(+), 11 deletions(-) diff --git a/tests/test_autoint.py b/tests/test_autoint.py index 166dd8f7..836586a9 100644 --- a/tests/test_autoint.py +++ b/tests/test_autoint.py @@ -77,6 +77,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -90,6 +91,7 @@ def test_regression( @pytest.mark.parametrize("batch_norm_continuous_input", [True, False]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -99,7 +101,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_categorical_embedding.py b/tests/test_categorical_embedding.py index efce3742..45b051f9 100644 --- a/tests/test_categorical_embedding.py +++ 
b/tests/test_categorical_embedding.py @@ -123,6 +123,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -135,6 +136,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -145,7 +147,7 @@ def test_classification( return data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_common.py b/tests/test_common.py index daf77872..fe82ad99 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -750,7 +750,8 @@ def test_cross_validate_regression( [ "accuracy", None, - lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"].values), + #lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"].values), + lambda y_true, y_pred: accuracy_score(y_true, y_pred["target_prediction"].values), ], ) @pytest.mark.parametrize("return_oof", [True]) diff --git a/tests/test_danet.py b/tests/test_danet.py index dc01ecd7..c2070a28 100644 --- a/tests/test_danet.py +++ b/tests/test_danet.py @@ -78,6 +78,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -89,6 +90,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -96,7 +98,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_ft_transformer.py b/tests/test_ft_transformer.py index f83b5b58..6b2635c1 100644 --- a/tests/test_ft_transformer.py +++ b/tests/test_ft_transformer.py @@ -84,6 +84,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -95,6 +96,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -102,7 +104,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_gate.py b/tests/test_gate.py index aed057f0..0c6158ac 100644 --- a/tests/test_gate.py +++ b/tests/test_gate.py @@ -83,6 +83,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -94,6 +95,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, 
categorical_cols, continuous_feature_transform, @@ -101,7 +103,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_mdn.py b/tests/test_mdn.py index bd7fd546..ec7a98a9 100644 --- a/tests/test_mdn.py +++ b/tests/test_mdn.py @@ -75,6 +75,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -88,6 +89,7 @@ def test_regression( @pytest.mark.parametrize("num_gaussian", [1, 2]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -96,7 +98,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_node.py b/tests/test_node.py index 31dcb06a..2af0e1bb 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -81,6 +81,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -92,6 +93,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -99,7 +101,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_ssl.py b/tests/test_ssl.py index aa92ac07..a55d3087 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -147,6 +147,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -159,6 +160,7 @@ def test_regression( @pytest.mark.parametrize("freeze_backbone", [False]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -170,7 +172,7 @@ def test_classification( ssl_train, ssl_val = train_test_split(ssl, random_state=42) finetune_train, finetune_val = train_test_split(finetune, random_state=42) data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_tabnet.py b/tests/test_tabnet.py index 30b6ec99..efe54fb8 100644 --- a/tests/test_tabnet.py +++ b/tests/test_tabnet.py @@ -76,6 +76,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [[f"feature_{i}" for i in range(54)]], @@ -85,6 +86,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( 
classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -92,7 +94,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, diff --git a/tests/test_tabtransformer.py b/tests/test_tabtransformer.py index 2e0c1478..f2d6f692 100644 --- a/tests/test_tabtransformer.py +++ b/tests/test_tabtransformer.py @@ -82,6 +82,7 @@ def test_regression( assert pred_df.shape[0] == test.shape[0] +@pytest.mark.parametrize("multi_target", [False, True]) @pytest.mark.parametrize( "continuous_cols", [ @@ -93,6 +94,7 @@ def test_regression( @pytest.mark.parametrize("normalize_continuous_features", [True]) def test_classification( classification_data, + multi_target, continuous_cols, categorical_cols, continuous_feature_transform, @@ -100,7 +102,7 @@ def test_classification( ): (train, test, target) = classification_data data_config = DataConfig( - target=target, + target=target + ["feature_53"] if multi_target else target, continuous_cols=continuous_cols, categorical_cols=categorical_cols, continuous_feature_transform=continuous_feature_transform, From db8e18aa1c748a4c6cac28871d56afefad98501c Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 24 Apr 2024 20:15:26 +0000 Subject: [PATCH 05/11] preliminary fix for combine_prediction --- src/pytorch_tabular/tabular_model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 5adf1a7d..9072ebff 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -2036,14 +2036,15 @@ def _combine_predictions( elif callable(aggregate): bagged_pred = aggregate(pred_prob_l) if self.config.task == "classification": - classes = self.datamodule.label_encoder.classes_ + # FIXME need to iterate .label_encoder[x] + classes = self.datamodule.label_encoder[0].classes_ if aggregate == "hard_voting": pred_df = pd.DataFrame( np.concatenate(pred_prob_l, axis=1), columns=[ f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) - for c in self.datamodule.label_encoder.classes_ + for c in classes ], index=pred_idx, ) @@ -2052,7 +2053,8 @@ def _combine_predictions( final_pred = classes[np.argmax(bagged_pred, axis=1)] pred_df = pd.DataFrame( bagged_pred, - columns=[f"{c}_probability" for c in self.datamodule.label_encoder.classes_], + # FIXME + columns=[f"{c}_probability" for c in self.datamodule.label_encoder[0].classes_], index=pred_idx, ) pred_df["prediction"] = final_pred From 3b4f10ae02e22e1a43c0820081bced27211763f2 Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Tue, 30 Apr 2024 20:41:28 +0000 Subject: [PATCH 06/11] Documentation updates --- README.md | 2 +- docs/gs_usage.md | 2 +- examples/__only_for_dev__/adhoc_scaffold.py | 3 +-- examples/__only_for_dev__/to_test_dae.py | 3 +-- src/pytorch_tabular/models/base_model.py | 5 ++--- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index eaf62118..3032281a 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ from pytorch_tabular.config import ( data_config = DataConfig( target=[ "target" - ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented + ], # target should always be a list. 
continuous_cols=num_col_names, categorical_cols=cat_col_names, ) diff --git a/docs/gs_usage.md b/docs/gs_usage.md index 7285c519..af3c249f 100644 --- a/docs/gs_usage.md +++ b/docs/gs_usage.md @@ -14,7 +14,7 @@ from pytorch_tabular.config import ( data_config = DataConfig( target=[ "target" - ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented + ], # target should always be a list. continuous_cols=num_col_names, categorical_cols=cat_col_names, ) diff --git a/examples/__only_for_dev__/adhoc_scaffold.py b/examples/__only_for_dev__/adhoc_scaffold.py index d028827f..efedb4af 100644 --- a/examples/__only_for_dev__/adhoc_scaffold.py +++ b/examples/__only_for_dev__/adhoc_scaffold.py @@ -53,8 +53,7 @@ def print_metrics(y_true, y_pred, tag): from pytorch_tabular.models import GatedAdditiveTreeEnsembleConfig # noqa: E402 data_config = DataConfig( - # target should always be a list. Multi-targets are only supported for regression. - # Multi-Task Classification is not implemented + # target should always be a list. target=["target"], continuous_cols=num_col_names, categorical_cols=cat_col_names, diff --git a/examples/__only_for_dev__/to_test_dae.py b/examples/__only_for_dev__/to_test_dae.py index c00a5c1f..5d91d125 100644 --- a/examples/__only_for_dev__/to_test_dae.py +++ b/examples/__only_for_dev__/to_test_dae.py @@ -145,8 +145,7 @@ def print_metrics(y_true, y_pred, tag): lr = 1e-3 data_config = DataConfig( - # target should always be a list. Multi-targets are only supported for regression. - # Multi-Task Classification is not implemented + # target should always be a list. target=[target_name], continuous_cols=num_col_names, categorical_cols=cat_col_names, diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index a9fae691..ff77f3dc 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -118,6 +118,7 @@ def __init__( config.metrics.append(metric.__name__) config.metrics_params.append(vars(metric)) if config.task == "classification": + # Create a parameter set for each metric and target pair config.metrics_prob_input = self.custom_metrics_prob_inputs for i, mp in enumerate(config.metrics_params): mp.sub_params_list = [] @@ -129,11 +130,9 @@ def __init__( elif config.task == "classification": # Adding metric_params to config for classification task for i, mp in enumerate(config.metrics_params): + # Create a parameter set for each metric and target pair mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - #config.metrics_params[i][j]["task"] = mp.get("task", "multiclass") - #config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes) - config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)})) From 766f01f65ea07e2b622a4240f9672d4939b59e67 Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 1 May 2024 15:32:18 +0000 Subject: [PATCH 07/11] linter cleanup --- src/pytorch_tabular/models/base_model.py | 29 ++++++++++++++++------- src/pytorch_tabular/tabular_datamodule.py | 8 +++++-- src/pytorch_tabular/tabular_model.py | 6 +---- tests/test_common.py | 1 - 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index ff77f3dc..b8bacddb 100644 --- a/src/pytorch_tabular/models/base_model.py 
+++ b/src/pytorch_tabular/models/base_model.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torchmetrics -from omegaconf import DictConfig, OmegaConf, ListConfig +from omegaconf import DictConfig, OmegaConf from pandas import DataFrame from torch import Tensor from torch.optim import Optimizer @@ -123,9 +123,15 @@ def __init__( for i, mp in enumerate(config.metrics_params): mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), - "num_classes": mp.get("num_classes", num_classes)})) - + config.metrics_params[i].sub_params_list.append( + OmegaConf.create( + { + "task": mp.get("task", "multiclass"), + "num_classes": mp.get("num_classes", num_classes), + } + ) + ) + # Updating default metrics in config elif config.task == "classification": # Adding metric_params to config for classification task @@ -133,8 +139,11 @@ def __init__( # Create a parameter set for each metric and target pair mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), - "num_classes": mp.get("num_classes", num_classes)})) + config.metrics_params[i].sub_params_list.append( + OmegaConf.create( + {"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)} + ) + ) if config.metrics[i] in ( "accuracy", @@ -288,7 +297,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso logger=True, prog_bar=False, ) - start_index = end_index + start_index = end_index self.log( f"{tag}_loss", computed_loss, @@ -351,9 +360,11 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L end_index = start_index + cardinality y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze()) if prob_inp: - _metric = metric(y_hat_i, y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) + _metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]) else: - _metric = metric(torch.argmax(y_hat_i, dim=-1), y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) + _metric = metric( + torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i] + ) if len(self.hparams.output_cardinality) > 1: self.log( f"{tag}_{metric_str}_{i}", diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index b19cfd64..916c58b4 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -282,10 +282,14 @@ def _update_config(self, config) -> InferredConfig: elif config.task == "classification": # self._output_dim_clf = len(np.unique(self.train_dataset.y)) if config.target else None if self.train is not None: - output_cardinality = self.train[config.target].fillna("NA").nunique().tolist() if config.target else None + output_cardinality = ( + self.train[config.target].fillna("NA").nunique().tolist() if config.target else None + ) output_dim = sum(output_cardinality) else: - output_cardinality = self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None + output_cardinality = ( + self.train_dataset.data[config.target].fillna("NA").nunique().tolist() if config.target else None + ) output_dim = sum(output_cardinality) elif config.task == "ssl": output_cardinality = None diff --git 
a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index 9072ebff..9fc51610 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -2041,11 +2041,7 @@ def _combine_predictions( if aggregate == "hard_voting": pred_df = pd.DataFrame( np.concatenate(pred_prob_l, axis=1), - columns=[ - f"{c}_probability_fold_{i}" - for i in range(len(pred_prob_l)) - for c in classes - ], + columns=[f"{c}_probability_fold_{i}" for i in range(len(pred_prob_l)) for c in classes], index=pred_idx, ) pred_df["prediction"] = classes[final_pred] diff --git a/tests/test_common.py b/tests/test_common.py index fe82ad99..cb53dca2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -750,7 +750,6 @@ def test_cross_validate_regression( [ "accuracy", None, - #lambda y_true, y_pred: accuracy_score(y_true, y_pred["prediction"].values), lambda y_true, y_pred: accuracy_score(y_true, y_pred["target_prediction"].values), ], ) From a2bcf216d77777e6e9a28671f0d299c0ffaee10d Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Thu, 13 Jun 2024 15:30:11 +0000 Subject: [PATCH 08/11] Bugfix for metrics in multi-target classification --- src/pytorch_tabular/models/base_model.py | 35 +++++++++--------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index b8bacddb..ae131d03 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torchmetrics -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf, ListConfig from pandas import DataFrame from torch import Tensor from torch.optim import Optimizer @@ -118,32 +118,24 @@ def __init__( config.metrics.append(metric.__name__) config.metrics_params.append(vars(metric)) if config.task == "classification": - # Create a parameter set for each metric and target pair config.metrics_prob_input = self.custom_metrics_prob_inputs for i, mp in enumerate(config.metrics_params): mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - config.metrics_params[i].sub_params_list.append( - OmegaConf.create( - { - "task": mp.get("task", "multiclass"), - "num_classes": mp.get("num_classes", num_classes), - } - ) - ) - + config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), + "num_classes": mp.get("num_classes", num_classes)})) + # Updating default metrics in config elif config.task == "classification": # Adding metric_params to config for classification task for i, mp in enumerate(config.metrics_params): - # Create a parameter set for each metric and target pair mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - config.metrics_params[i].sub_params_list.append( - OmegaConf.create( - {"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)} - ) - ) + #config.metrics_params[i][j]["task"] = mp.get("task", "multiclass") + #config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes) + + config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), + "num_classes": mp.get("num_classes", num_classes)})) if config.metrics[i] in ( "accuracy", @@ -297,7 +289,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso logger=True, prog_bar=False, ) - start_index = 
end_index + start_index = end_index self.log( f"{tag}_loss", computed_loss, @@ -360,11 +352,9 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L end_index = start_index + cardinality y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze()) if prob_inp: - _metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]) + _metric = metric(y_hat_i, y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) else: - _metric = metric( - torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i] - ) + _metric = metric(torch.argmax(y_hat_i, dim=-1), y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) if len(self.hparams.output_cardinality) > 1: self.log( f"{tag}_{metric_str}_{i}", @@ -375,6 +365,7 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L prog_bar=False, ) _metrics.append(_metric) + start_index = end_index avg_metric = torch.stack(_metrics, dim=0).sum() metrics.append(avg_metric) self.log( From 31425df90e72492861b4776d5bfd5cd6f6b2ef3e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Jun 2024 15:32:29 +0000 Subject: [PATCH 09/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_tabular/models/base_model.py | 33 ++++++++++++++++-------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/src/pytorch_tabular/models/base_model.py b/src/pytorch_tabular/models/base_model.py index ae131d03..77b36f49 100644 --- a/src/pytorch_tabular/models/base_model.py +++ b/src/pytorch_tabular/models/base_model.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torchmetrics -from omegaconf import DictConfig, OmegaConf, ListConfig +from omegaconf import DictConfig, OmegaConf from pandas import DataFrame from torch import Tensor from torch.optim import Optimizer @@ -122,20 +122,29 @@ def __init__( for i, mp in enumerate(config.metrics_params): mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), - "num_classes": mp.get("num_classes", num_classes)})) - + config.metrics_params[i].sub_params_list.append( + OmegaConf.create( + { + "task": mp.get("task", "multiclass"), + "num_classes": mp.get("num_classes", num_classes), + } + ) + ) + # Updating default metrics in config elif config.task == "classification": # Adding metric_params to config for classification task for i, mp in enumerate(config.metrics_params): mp.sub_params_list = [] for j, num_classes in enumerate(inferred_config.output_cardinality): - #config.metrics_params[i][j]["task"] = mp.get("task", "multiclass") - #config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes) + # config.metrics_params[i][j]["task"] = mp.get("task", "multiclass") + # config.metrics_params[i][j]["num_classes"] = mp.get("num_classes", num_classes) - config.metrics_params[i].sub_params_list.append(OmegaConf.create({"task": mp.get("task", "multiclass"), - "num_classes": mp.get("num_classes", num_classes)})) + config.metrics_params[i].sub_params_list.append( + OmegaConf.create( + {"task": mp.get("task", "multiclass"), "num_classes": mp.get("num_classes", num_classes)} + ) + ) if config.metrics[i] in ( "accuracy", @@ -289,7 +298,7 @@ def calculate_loss(self, output: Dict, y: torch.Tensor, tag: str) -> torch.Tenso 
logger=True, prog_bar=False, ) - start_index = end_index + start_index = end_index self.log( f"{tag}_loss", computed_loss, @@ -352,9 +361,11 @@ def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, tag: str) -> L end_index = start_index + cardinality y_hat_i = nn.Softmax(dim=-1)(y_hat[:, start_index:end_index].squeeze()) if prob_inp: - _metric = metric(y_hat_i, y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) + _metric = metric(y_hat_i, y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i]) else: - _metric = metric(torch.argmax(y_hat_i, dim=-1), y[:,i:i+1].squeeze(), **metric_params.sub_params_list[i]) + _metric = metric( + torch.argmax(y_hat_i, dim=-1), y[:, i : i + 1].squeeze(), **metric_params.sub_params_list[i] + ) if len(self.hparams.output_cardinality) > 1: self.log( f"{tag}_{metric_str}_{i}", From a08e542469c0ca524258f405f570ae9d80ff5425 Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 10 Jul 2024 20:09:55 +0000 Subject: [PATCH 10/11] Added new tutorial for multi-target classification --- .../15-Multi Target Classification.ipynb | 2008 +++++++++++++++++ 1 file changed, 2008 insertions(+) create mode 100644 docs/tutorials/15-Multi Target Classification.ipynb diff --git a/docs/tutorials/15-Multi Target Classification.ipynb b/docs/tutorials/15-Multi Target Classification.ipynb new file mode 100644 index 00000000..6d882cc2 --- /dev/null +++ b/docs/tutorials/15-Multi Target Classification.ipynb @@ -0,0 +1,2008 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "from sklearn.datasets import make_classification\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import accuracy_score, f1_score\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pytorch_tabular.utils import make_mixed_dataset, print_metrics\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# Load dataset\n", + "data, cat_col_names, num_col_names = make_mixed_dataset(task=\"classification\", n_samples=10000, n_features=8, n_categories=4, weights=[0.8], random_state=42)\n", + "\n", + "# Create a new, second target\n", + "data['second_target'] = 0\n", + "for c in cat_col_names:\n", + " data.second_target += data[c]\n", + "\n", + "# Correlate it to 1st target\n", + "data.second_target += (data.target == 'class_0')\n", + "\n", + "# Create random discrete noise to make task non-trivial\n", + "random_noise = np.random.normal(0, 0.8, data.shape[0]).astype(int)\n", + "data.second_target = (data.second_target + random_noise).mod(3) \n", + "\n", + "# Now that dataset is complete, we can perform train/test split\n", + "train, test = train_test_split(data, random_state=42)\n", + "train, val = train_test_split(train, random_state=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "target\n", + "class_0 0.7968\n", + "class_1 0.2032\n", + "Name: proportion, dtype: float64" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.target.value_counts(normalize=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "second_target\n", + "2.0 0.3364\n", + "1.0 0.3354\n", + "0.0 0.3282\n", + "Name: 
proportion, dtype: float64" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.second_target.value_counts(normalize=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "# Importing the Library" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "from pytorch_tabular import TabularModel\n", + "from pytorch_tabular.models import CategoryEmbeddingModelConfig\n", + "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig\n", + "from pytorch_tabular.models.common.heads import LinearHeadConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "results = []" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "## Define the Configs\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "trainer_config = TrainerConfig(\n", + " auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n", + " batch_size=1024,\n", + " max_epochs=100,\n", + " early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n", + " early_stopping_mode = \"min\", # Set the mode as min because for val_loss, lower is better\n", + " early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n", + " checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n", + " load_best=True, # After training, load the best checkpoint\n", + "# accelerator=\"cpu\"\n", + ")\n", + "optimizer_config = OptimizerConfig()\n", + "\n", + "head_config = LinearHeadConfig(\n", + " layers=\"\", # No additional layer in head, just a mapping layer to output_dim\n", + " dropout=0.1,\n", + " initialization=\"kaiming\"\n", + ").__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n", + "\n", + "model_config = CategoryEmbeddingModelConfig(\n", + " task=\"classification\",\n", + " layers=\"1024-512-512\", # Number of nodes in each layer\n", + " activation=\"LeakyReLU\", # Activation between each layers\n", + " head = \"LinearHead\", #Linear Head\n", + " head_config = head_config, # Linear Head Config\n", + " learning_rate = 1e-3,\n", + " metrics=[\"f1_score\",\"accuracy\",\"auroc\"], \n", + " metrics_prob_input=[True, False, True] # f1_score needs probability scores, while accuracy doesn't\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data Config For Each Target" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "data_config_first_target = DataConfig(\n", + " target=['target'], #target should always be a list\n", + " continuous_cols=num_col_names,\n", + " categorical_cols=cat_col_names\n", + ")\n", + "\n", + "data_config_second_target = DataConfig(\n", + " target=['second_target'], #target should always be a list\n", + " continuous_cols=num_col_names,\n", + " categorical_cols=cat_col_names\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "## Training the Single-Target Model " + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "Collapsed": "false", + "tags": [] + }, 
+ "outputs": [ + { + "data": { + "text/html": [ + "
2024-07-10 11:38:02,110 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m110\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:02,126 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders                   \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m126\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:02,130 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for          \n",
+       "classification task                                                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m130\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n", + "classification task \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:02,150 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: CategoryEmbeddingModel \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m150\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:02,190 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer                       \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m190\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:02,366 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m366\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d2cae94714354ea09299a9e44f238d89", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Finding best initial lr: 0%| | 0/100 [00:002024-07-10 11:38:10,876 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 0.025118864315095822. For plot\n", + "and detailed analysis, use `find_learning_rate` method. \n", + "\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:10\u001b[0m,\u001b[1;36m876\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.025118864315095822\u001b[0m. For plot\n", + "and detailed analysis, use `find_learning_rate` method. \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:10,881 - {pytorch_tabular.tabular_model:669} - INFO - Training Started                            \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:10\u001b[0m,\u001b[1;36m881\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+       "┃    Name              Type                       Params ┃\n",
+       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+       "│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
+       "│ 1 │ _embedding_layer │ Embedding1dLayer          │     68 │\n",
+       "│ 2 │ head             │ LinearHead                │  1.0 K │\n",
+       "│ 3 │ loss             │ CrossEntropyLoss          │      0 │\n",
+       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n", + "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head │ LinearHead │ 1.0 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n", + "└───┴──────────────────┴───────────────────────────┴────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 805 K                                                                                            \n",
+       "Non-trainable params: 0                                                                                            \n",
+       "Total params: 805 K                                                                                                \n",
+       "Total estimated model params size (MB): 3                                                                          \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 805 K \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 805 K \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3 \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d67a105c98184698acff2a3caab52b74", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:29,013 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:29\u001b[0m,\u001b[1;36m013\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:29,014 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:29\u001b[0m,\u001b[1;36m014\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "543ac972faed487c9fa81364d979f3f7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│       test_accuracy           0.9484000205993652     │\n",
+       "│        test_auroc             0.9717501997947693     │\n",
+       "│       test_f1_score           0.9484000205993652     │\n",
+       "│         test_loss             0.17070142924785614    │\n",
+       "│        test_loss_0            0.17070142924785614    │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9484000205993652 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9717501997947693 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9484000205993652 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.17070142924785614 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.17070142924785614 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tabular_model = TabularModel(\n", + " data_config=data_config_first_target,\n", + " model_config=model_config,\n", + " optimizer_config=optimizer_config,\n", + " trainer_config=trainer_config,\n", + ")\n", + "\n", + "tabular_model.fit(train=train, validation=val)\n", + "\n", + "result = tabular_model.evaluate(test)\n", + "\n", + "result = {k: float(v) for k,v in result[0].items()}\n", + "result = pd.DataFrame({'f1':result['test_f1_score'],'auroc':result['test_auroc']},\n", + " index=['1st Target (single mode)'])\n", + "results.append(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now train separately on the 2nd target" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
2024-07-10 11:38:30,958 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m958\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:30,972 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders                   \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m972\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:30,976 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for          \n",
+       "classification task                                                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m976\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n", + "classification task \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:30,996 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: CategoryEmbeddingModel \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m996\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:31,035 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer                       \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:31\u001b[0m,\u001b[1;36m035\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:31,054 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:31\u001b[0m,\u001b[1;36m054\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ace5c6e2844d4ff58f2550a8e9da9ee4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Finding best initial lr: 0%| | 0/100 [00:002024-07-10 11:38:38,580 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 0.0002511886431509582. For \n", + "plot and detailed analysis, use `find_learning_rate` method. \n", + "\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:38\u001b[0m,\u001b[1;36m580\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.0002511886431509582\u001b[0m. For \n", + "plot and detailed analysis, use `find_learning_rate` method. \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:38,584 - {pytorch_tabular.tabular_model:669} - INFO - Training Started                            \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:38\u001b[0m,\u001b[1;36m584\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+       "┃    Name              Type                       Params ┃\n",
+       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+       "│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
+       "│ 1 │ _embedding_layer │ Embedding1dLayer          │     68 │\n",
+       "│ 2 │ head             │ LinearHead                │  1.5 K │\n",
+       "│ 3 │ loss             │ CrossEntropyLoss          │      0 │\n",
+       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n", + "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head │ LinearHead │ 1.5 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n", + "└───┴──────────────────┴───────────────────────────┴────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 806 K                                                                                            \n",
+       "Non-trainable params: 0                                                                                            \n",
+       "Total params: 806 K                                                                                                \n",
+       "Total estimated model params size (MB): 3                                                                          \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 806 K \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 806 K \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3 \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "483e9eeb0b754be9b966618d11ab2baa", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:54,807 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:54\u001b[0m,\u001b[1;36m807\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:54,808 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:54\u001b[0m,\u001b[1;36m808\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0fad2da095be4448ae4507d193f28c4a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│       test_accuracy           0.6808000206947327     │\n",
+       "│        test_auroc             0.8134375810623169     │\n",
+       "│       test_f1_score           0.6808000206947327     │\n",
+       "│         test_loss             0.8243984580039978     │\n",
+       "│        test_loss_0            0.8243984580039978     │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6808000206947327 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8134375810623169 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6808000206947327 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8243984580039978 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8243984580039978 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "tabular_model = TabularModel(\n", + " data_config=data_config_second_target,\n", + " model_config=model_config,\n", + " optimizer_config=optimizer_config,\n", + " trainer_config=trainer_config,\n", + ")\n", + "\n", + "tabular_model.fit(train=train, validation=val)\n", + "\n", + "result = tabular_model.evaluate(test)\n", + "\n", + "result = {k: float(v) for k,v in result[0].items()}\n", + "result = pd.DataFrame({'f1':result['test_f1_score'],'auroc':result['test_auroc']},\n", + " index=['2nd Target (single mode)'])\n", + "\n", + "results.append(result)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-Target Training\n", + "\n", + "Instead of training one model for the first target, and another model for the second target, we can train a single model that will make a prediction for both targets.\n", + "\n", + "This is usually beneficial in reducing training time, but may also lead to better results since the model may have a better representation (embedding) by learning from multiple targets.\n", + "\n", + "To perform multi-target training, we only need to modify the 'target' field in the data_config to include a list of multiple targets.\n", + "\n", + "Results are reported as the sum of all metrics (f1, AU-ROC, etc.) over the targets, as well as a list of results for each target with the suffix '_n' (starting at n=0); for example, the f1 score for the 2nd target is test_f1_score_1." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
2024-07-10 11:38:57,010 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m010\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:57,023 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders                   \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m023\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:57,028 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for          \n",
+       "classification task                                                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m028\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n", + "classification task \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:57,049 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: CategoryEmbeddingModel \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m049\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:57,089 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer                       \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m089\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:38:57,115 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m115\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c952afd65ca84628a3d1142ddc2e936e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Finding best initial lr: 0%| | 0/100 [00:002024-07-10 11:39:06,162 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 4.786300923226385e-05. For \n", + "plot and detailed analysis, use `find_learning_rate` method. \n", + "\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:39:06\u001b[0m,\u001b[1;36m162\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m4.786300923226385e-05\u001b[0m. For \n", + "plot and detailed analysis, use `find_learning_rate` method. \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:39:06,166 - {pytorch_tabular.tabular_model:669} - INFO - Training Started                            \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:39:06\u001b[0m,\u001b[1;36m166\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+       "┃    Name              Type                       Params ┃\n",
+       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+       "│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
+       "│ 1 │ _embedding_layer │ Embedding1dLayer          │     68 │\n",
+       "│ 2 │ head             │ LinearHead                │  2.6 K │\n",
+       "│ 3 │ loss             │ CrossEntropyLoss          │      0 │\n",
+       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ CategoryEmbeddingBackbone │ 804 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n", + "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head │ LinearHead │ 2.6 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n", + "└───┴──────────────────┴───────────────────────────┴────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 807 K                                                                                            \n",
+       "Non-trainable params: 0                                                                                            \n",
+       "Total params: 807 K                                                                                                \n",
+       "Total estimated model params size (MB): 3                                                                          \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 807 K \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 807 K \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3 \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "49d66315050e49ed94823d49ea1ff8c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=100` reached.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:07,357 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:07\u001b[0m,\u001b[1;36m357\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:07,363 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:07\u001b[0m,\u001b[1;36m363\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b06cba39db194aaea56bc99fd05c334b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│       test_accuracy           1.5648000240325928     │\n",
+       "│      test_accuracy_0           0.946399986743927     │\n",
+       "│      test_accuracy_1           0.618399977684021     │\n",
+       "│        test_auroc             1.7456306219100952     │\n",
+       "│       test_auroc_0            0.9698674082756042     │\n",
+       "│       test_auroc_1            0.7757631540298462     │\n",
+       "│       test_f1_score           1.5648000240325928     │\n",
+       "│      test_f1_score_0           0.946399986743927     │\n",
+       "│      test_f1_score_1           0.618399977684021     │\n",
+       "│         test_loss             1.0632164478302002     │\n",
+       "│        test_loss_0            0.16491228342056274    │\n",
+       "│        test_loss_1            0.8983041048049927     │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.5648000240325928 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.946399986743927 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.618399977684021 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.7456306219100952 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9698674082756042 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7757631540298462 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.5648000240325928 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.946399986743927 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.618399977684021 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0632164478302002 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.16491228342056274 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8983041048049927 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data_config_multi = DataConfig(\n", + " target=['target','second_target'], #target should always be a list\n", + " continuous_cols=num_col_names,\n", + " categorical_cols=cat_col_names\n", + ")\n", + "\n", + "tabular_model = TabularModel(\n", + " data_config=data_config_multi,\n", + " model_config=model_config,\n", + " optimizer_config=optimizer_config,\n", + " trainer_config=trainer_config,\n", + ")\n", + "\n", + "tabular_model.fit(train=train, validation=val)\n", + "\n", + "result = tabular_model.evaluate(test)\n", + "\n", + "result = {k: float(v) for k,v in result[0].items()}\n", + "result1 = pd.DataFrame({'f1':result['test_f1_score_0'],'auroc':result['test_auroc_0']},\n", + " index=['1st Target (multi-target mode)'])\n", + "results.append(result1)\n", + "result2 = pd.DataFrame({'f1':result['test_f1_score_1'],'auroc':result['test_auroc_1']},\n", + " index=['2nd Target (multi-target mode)'])\n", + "results.append(result2)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 f1auroc
1st Target (single mode)0.9484000.971750
2nd Target (single mode)0.6808000.813438
1st Target (multi-target mode)0.9464000.969867
2nd Target (multi-target mode)0.6184000.775763
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res_df = pd.concat(results)\n", + "res_df.style.highlight_max(color=\"lightgreen\",axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this run, we see that the multi-target model performed on par with the single-target model on the 1st target, but slightly worse on the 2nd target. This may vary for this artificial dataset depending on random number generation, and in general multi-target training may not outperform the single-target variants. Additional tuning, e.g. a larger embedding size, may be needed for the multi-target model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Deep Learning Model\n", + "\n", + "We can also test whether a deeper model benefits from the shared embedding. In this case, we test Gandalf with multi-target classification, similar to our experiment above.\n", + "\n", + "Note: Without an accelerator (GPU), this training will take considerable time on a CPU." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_tabular.models import GANDALFConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
2024-07-10 11:40:09,147 - {pytorch_tabular.tabular_model:140} - INFO - Experiment Tracking is turned off           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m147\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 42\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:09,160 - {pytorch_tabular.tabular_model:541} - INFO - Preparing the DataLoaders                   \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m160\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:09,165 - {pytorch_tabular.tabular_datamodule:507} - INFO - Setting up the datamodule for          \n",
+       "classification task                                                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m165\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for \n", + "classification task \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:09,188 - {pytorch_tabular.tabular_model:591} - INFO - Preparing the Model: GANDALFModel           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m188\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: GANDALFModel \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:09,212 - {pytorch_tabular.tabular_model:338} - INFO - Preparing the Trainer                       \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m212\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:09,240 - {pytorch_tabular.tabular_model:647} - INFO - Auto LR Find Started                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m240\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3d222f1be7284f9aae6216f29e17ea4b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Finding best initial lr: 0%| | 0/100 [00:002024-07-10 11:40:22,112 - {pytorch_tabular.tabular_model:660} - INFO - Suggested LR: 0.10964781961431852. For plot \n", + "and detailed analysis, use `find_learning_rate` method. \n", + "\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:22\u001b[0m,\u001b[1;36m112\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.10964781961431852\u001b[0m. For plot \n", + "and detailed analysis, use `find_learning_rate` method. \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:22,115 - {pytorch_tabular.tabular_model:669} - INFO - Training Started                            \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:22\u001b[0m,\u001b[1;36m115\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
+       "┃    Name              Type              Params ┃\n",
+       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
+       "│ 0 │ _backbone        │ GANDALFBackbone  │  9.6 K │\n",
+       "│ 1 │ _embedding_layer │ Embedding1dLayer │     68 │\n",
+       "│ 2 │ _head            │ Sequential       │     90 │\n",
+       "│ 3 │ loss             │ CrossEntropyLoss │      0 │\n",
+       "└───┴──────────────────┴──────────────────┴────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone │ GANDALFBackbone │ 9.6 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │ 68 │\n", + "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head │ Sequential │ 90 │\n", + "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss │ CrossEntropyLoss │ 0 │\n", + "└───┴──────────────────┴──────────────────┴────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 9.8 K                                                                                            \n",
+       "Non-trainable params: 0                                                                                            \n",
+       "Total params: 9.8 K                                                                                                \n",
+       "Total estimated model params size (MB): 0                                                                          \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 9.8 K \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 9.8 K \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0 \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "833c09b04f5a4841a69215c4161a94ec", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:28,324 - {pytorch_tabular.tabular_model:680} - INFO - Training the model completed                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:28\u001b[0m,\u001b[1;36m324\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
2024-07-10 11:40:28,326 - {pytorch_tabular.tabular_model:1512} - INFO - Loading the best model                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:28\u001b[0m,\u001b[1;36m326\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8db6da2e21c24377a31514cc5a245919", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric               DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│       test_accuracy            1.210800051689148     │\n",
+       "│      test_accuracy_0          0.9120000004768372     │\n",
+       "│      test_accuracy_1           0.298799991607666     │\n",
+       "│        test_auroc             1.4514379501342773     │\n",
+       "│       test_auroc_0            0.9510558843612671     │\n",
+       "│       test_auroc_1             0.500382125377655     │\n",
+       "│       test_f1_score            1.210800051689148     │\n",
+       "│      test_f1_score_0          0.9120000004768372     │\n",
+       "│      test_f1_score_1           0.298799991607666     │\n",
+       "│         test_loss             1.3304287195205688     │\n",
+       "│        test_loss_0            0.22366906702518463    │\n",
+       "│        test_loss_1             1.106759786605835     │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.210800051689148 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9120000004768372 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.298799991607666 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.4514379501342773 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9510558843612671 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.500382125377655 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.210800051689148 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9120000004768372 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_score_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.298799991607666 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.3304287195205688 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss_0 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.22366906702518463 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss_1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.106759786605835 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
f1auroc
1st Target (single mode)0.94840.971750
2nd Target (single mode)0.68080.813438
1st Target (multi-target mode)0.94640.969867
2nd Target (multi-target mode)0.61840.775763
1st Target (Gandalf, multi-target mode)0.91200.951056
2nd Target (Gandalf, multi-target mode)0.29880.500382
\n", + "
" + ], + "text/plain": [ + " f1 auroc\n", + "1st Target (single mode) 0.9484 0.971750\n", + "2nd Target (single mode) 0.6808 0.813438\n", + "1st Target (multi-target mode) 0.9464 0.969867\n", + "2nd Target (multi-target mode) 0.6184 0.775763\n", + "1st Target (Gandalf, multi-target mode) 0.9120 0.951056\n", + "2nd Target (Gandalf, multi-target mode) 0.2988 0.500382" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_config_2nd = DataConfig(\n", + " target=['target', 'second_target'],\n", + " continuous_cols=num_col_names,\n", + " categorical_cols=cat_col_names,\n", + ")\n", + "\n", + "trainer_gl_config = TrainerConfig(\n", + " auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n", + " batch_size=1024,\n", + " max_epochs=50,\n", + " early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n", + " early_stopping_mode = \"min\", # Set the mode as min because for val_loss, lower is better\n", + " early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n", + " checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n", + " load_best=True, # After training, load the best checkpoint\n", + " # accelerator=\"cpu\"\n", + ")\n", + "\n", + "model_config_gandalf = GANDALFConfig(task='classification',\n", + " metrics=[\"f1_score\",\"accuracy\",\"auroc\"], \n", + " metrics_prob_input=[True, False, True], # f1_score needs probability scores, while accuracy doesn't,\n", + " gflu_stages=6, gflu_dropout=0.2\n", + ")\n", + "\n", + "\n", + "tabular_model = TabularModel(\n", + " data_config=data_config_2nd,\n", + " model_config=model_config_gandalf,\n", + " optimizer_config=optimizer_config,\n", + " trainer_config=trainer_gl_config,\n", + ")\n", + "\n", + "tabular_model.fit(train=train, validation=val)\n", + "\n", + "result = tabular_model.evaluate(test)\n", + "\n", + "result = {k: float(v) for k,v in result[0].items()}\n", + "result1 = pd.DataFrame({'f1':result['test_f1_score_0'],'auroc':result['test_auroc_0']},\n", + " index=['1st Target (Gandalf, multi-target mode)'])\n", + "results.append(result1)\n", + "result2 = pd.DataFrame({'f1':result['test_f1_score_1'],'auroc':result['test_auroc_1']},\n", + " index=['2nd Target (Gandalf, multi-target mode)'])\n", + "results.append(result2)\n", + "\n", + "pd.concat(results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results Summary\n", + "\n", + "Here we see that the deeper Gandalf model with multi-target performed worse than either variant of the Category Embedding Model. As before, additional hyperparameter tuning may be required for a fair comparison." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 660dc1c199096e7e14ca13becb0786c3df5fd4df Mon Sep 17 00:00:00 2001 From: Yony Bresler Date: Wed, 10 Jul 2024 20:11:33 +0000 Subject: [PATCH 11/11] Minor update to documentation for multi-target classification --- ...1-Approaching Any Tabular Problem with PyTorch Tabular.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb b/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb index af4ddbf2..ab803092 100644 --- a/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb +++ b/docs/tutorials/01-Approaching Any Tabular Problem with PyTorch Tabular.ipynb @@ -532,7 +532,7 @@ "data_config = DataConfig(\n", " target=[\n", " target_col\n", - " ], # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n", + " ], # target should always be a list\n", " continuous_cols=num_col_names,\n", " categorical_cols=cat_col_names,\n", ")\n",