Features/random sampler (#633)
Allow dropping in_df ids when the dataloader uses random sampling

---------

Co-authored-by: Ben Epstein <ben@rungalileo.io>
franz101 and Ben Epstein authored May 31, 2023
1 parent c6c552f commit cc58a18
Showing 9 changed files with 166 additions and 6 deletions.
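
For orientation, a minimal usage sketch of the feature this commit adds (the toy model and dataset are hypothetical stand-ins, and the usual dq.init / dq.log_* calls are assumed to have run already):

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    from dataquality.integrations.torch import watch

    # Toy stand-ins; a real run would use the project's own model and dataset.
    train_ds = TensorDataset(torch.arange(8), torch.randn(8, 4), torch.zeros(8, dtype=torch.long))
    model = torch.nn.Linear(4, 2)

    train_loader = DataLoader(train_ds, batch_size=4, sampler=RandomSampler(train_ds))

    # With a RandomSampler (or WeightedRandomSampler) the new flag must be True,
    # otherwise logging fails at the end of training (see torch.py below).
    watch(model, dataloaders=[train_loader], dataloader_random_sampling=True)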
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -30,7 +30,7 @@
dataquality.get_insights()
"""

__version__ = "v0.8.45"
__version__ = "v0.8.46"

import sys
from typing import Any, List, Optional
18 changes: 18 additions & 0 deletions dataquality/integrations/torch.py
@@ -5,11 +5,13 @@
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from transformers.modeling_outputs import TokenClassifierOutput

import dataquality as dq
from dataquality.analytics import Analytics
from dataquality.clients.api import ApiClient
from dataquality.core.log import get_data_logger
from dataquality.exceptions import GalileoException
from dataquality.schemas.task_type import TaskType
from dataquality.schemas.torch import DimensionSlice, HelperData, InputDim, Layer
@@ -193,6 +195,7 @@ def watch(
logits_fn: Optional[Callable] = None,
last_hidden_state_layer: Union[Module, str, None] = None,
unpatch_on_start: bool = False,
dataloader_random_sampling: bool = False,
) -> None:
"""
wraps a PyTorch model and optionally dataloaders to log the
@@ -229,6 +232,13 @@
:param model: Pytorch Model to be wrapped
:param dataloaders: List of dataloaders to be wrapped
:param last_hidden_state_layer: Layer to extract the embeddings from
:param unpatch_on_start: Force unpatching of dataloaders
instead of global patching
:param dataloader_random_sampling: Whether the dataloader uses a
RandomSampler or WeightedRandomSampler. If it does, this must be
set to True, otherwise logging will fail at the end of training.
"""
a.log_function("torch/watch")
assert dq.config.task_type, GalileoException(
@@ -244,6 +254,8 @@
)

helper_data = dq.get_model_logger().logger_config.helper_data
logger_config = get_data_logger().logger_config

print("Attaching dataquality to model and dataloaders")
tl = TorchLogger(
model=model,
@@ -265,6 +277,10 @@
)
if len(dataloaders) > 0 and is_single_process_dataloader:
for dataloader in dataloaders:
if not isinstance(getattr(dataloader, "sampler", None), SequentialSampler):
logger_config = get_data_logger().logger_config
logger_config.dataloader_random_sampling = True

assert isinstance(dataloader, DataLoader), GalileoException(
"Invalid dataloader. Must be a pytorch dataloader"
"from torch.utils.data import DataLoader..."
@@ -283,6 +299,8 @@
# Patch the dataloader class globally
# Can be unpatched with unwatch()
patch_dataloaders(tl.helper_data)
if dataloader_random_sampling:
logger_config.dataloader_random_sampling = True


def unwatch(model: Optional[Module] = None, force: bool = True) -> None:
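
The SequentialSampler check above is what flips the config flag automatically; a small self-contained sketch of that detection (toy dataset, illustrative only):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.sampler import SequentialSampler

    ds = TensorDataset(torch.arange(8))
    shuffled = DataLoader(ds, batch_size=4, shuffle=True)  # gets a RandomSampler
    ordered = DataLoader(ds, batch_size=4)                 # gets a SequentialSampler

    for loader in (shuffled, ordered):
        sampler = getattr(loader, "sampler", None)
        if not isinstance(sampler, SequentialSampler):
            # Non-sequential order means output ids may not line up with the
            # logged inputs, so random-sampling handling is switched on.
            print(type(sampler).__name__, "-> dataloader_random_sampling = True")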
6 changes: 5 additions & 1 deletion dataquality/loggers/data_logger/base_data_logger.py
@@ -446,7 +446,11 @@ def process_in_out_frames(
:param epoch_or_inf_name: The epoch or inference name we are uploading for
"""
validate_unique_ids(out_frame, epoch_or_inf_name)
in_out = _join_in_out_frames(in_frame, out_frame)
allow_missing_in_df_ids = cls.logger_config.dataloader_random_sampling

in_out = _join_in_out_frames(
in_frame, out_frame, allow_missing_in_df_ids=allow_missing_in_df_ids
)

dataframes = cls.separate_dataframe(in_out, prob_only, split)
# These df vars will be used in upload_in_out_frames
19 changes: 18 additions & 1 deletion dataquality/loggers/data_logger/image_classification.py
@@ -3,8 +3,9 @@
import glob
import os
import tempfile
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Set, Union

import numpy as np
import pandas as pd
import vaex
from vaex.dataframe import DataFrame
@@ -267,9 +268,24 @@ def process_in_out_frames(
data: "id", "pred" + all the other cols not in emb or prob
"""
validate_unique_ids(out_frame, epoch_or_inf_name)
allow_missing_in_df_ids = cls.logger_config.dataloader_random_sampling
filter_ids: Set[int] = set()
if allow_missing_in_df_ids:
observed_ids = image_classification_logger_config.observed_ids
keys = [k for k in observed_ids.keys() if split in k]
if len(keys):
filter_ids = set(observed_ids[keys[0]])
for k in keys:
filter_ids = filter_ids.intersection(observed_ids[k])

emb_cols = ["id"] if prob_only else ["id", "emb"]
emb_df = out_frame[emb_cols]
if allow_missing_in_df_ids:
filter_ids_arr: np.ndarray = np.array(list(filter_ids))
del filter_ids
in_frame = in_frame[in_frame["id"].isin(filter_ids_arr)]
out_frame = out_frame[out_frame["id"].isin(filter_ids_arr)]

# The in_frame has gold, so we join with the out_frame to get the probabilities
prob_df = out_frame.join(in_frame[["id", "gold"]], on="id")[
cls._get_prob_cols()
@@ -288,6 +304,7 @@
# prob_df on the server. This is confusing code
data_cols = in_frame.get_column_names() + ["pred"]
data_cols = ["id"] + [c for c in data_cols if c not in remove_cols]

data_df = in_frame.join(out_frame[["id", "pred"]], on="id")[data_cols]

dataframes = BaseLoggerDataFrames(prob=prob_df, emb=emb_df, data=data_df)
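
A short sketch of the id filtering introduced above, with illustrative observed_ids values: only ids observed in every logged epoch of a split survive the intersection.

    # Illustrative bookkeeping: ids seen per "<split>_<epoch>" key.
    observed_ids = {
        "training_0": {0, 1, 2, 3},
        "training_1": {1, 2, 3},
        "validation_0": {0, 1},
    }
    split = "training"

    keys = [k for k in observed_ids.keys() if split in k]
    filter_ids = set(observed_ids[keys[0]])
    for k in keys:
        filter_ids = filter_ids.intersection(observed_ids[k])

    print(filter_ids)  # {1, 2, 3}: id 0 was only sampled in epoch 0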
1 change: 1 addition & 0 deletions dataquality/loggers/logger_config/base_logger_config.py
@@ -36,6 +36,7 @@ class BaseLoggerConfig(BaseModel):
finish: Callable = lambda: None # Overwritten in Semantic Segmentation
# True when calling `init` with a run that already exists
existing_run: bool = False
dataloader_random_sampling = False

class Config:
validate_assignment = True
1 change: 1 addition & 0 deletions dataquality/loggers/logger_config/image_classification.py
@@ -9,6 +9,7 @@ class ImageClassificationLoggerConfig(TextClassificationLoggerConfig):
# Keep track of the ids that have been observed in the current epoch
# the key is the split and epoch like observed_ids["train_0"] = {0, 1, 2, 3}
observed_ids: Dict[str, set] = dict()
all_ids: Dict[str, set] = dict()


image_classification_logger_config = ImageClassificationLoggerConfig()
11 changes: 10 additions & 1 deletion dataquality/utils/helpers.py
@@ -5,6 +5,8 @@

from typing_extensions import ParamSpec

from dataquality.exceptions import GalileoException

T = TypeVar("T")
P = ParamSpec("P")
GALILEO_DISABLED = "GALILEO_DISABLED"
@@ -85,7 +87,14 @@ def map_indices_to_ids(id_map: List, indices: List) -> List:
:param indices: The indices to map
:return: The ids
"""
return [id_map[i] for i in indices]
try:
return [id_map[i] for i in indices]
except IndexError:
raise GalileoException(
"The indicies of the model output are not matching the logged data "
"samples. If you are using RandomSampler or WeightedRandomSampler, "
"pass dataloader_random_sampling=True to the watch function"
)


def open_console_url(link: Optional[str] = "") -> None:
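
To see what the new guard catches (values made up): an out-of-range index now surfaces as a GalileoException that points at the new watch flag, instead of a bare IndexError.

    from dataquality.exceptions import GalileoException
    from dataquality.utils.helpers import map_indices_to_ids

    id_map = [10, 11, 12]                      # ids logged for the dataset
    print(map_indices_to_ids(id_map, [0, 2]))  # [10, 12]

    try:
        map_indices_to_ids(id_map, [5])        # index 5 was never logged
    except GalileoException as err:
        print(err)  # suggests passing dataloader_random_sampling=True to watch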
14 changes: 12 additions & 2 deletions dataquality/utils/vaex.py
@@ -26,7 +26,9 @@
DEFAULT_DATA_EMBS_MODEL = "all-MiniLM-L6-v2"


def _join_in_out_frames(in_df: DataFrame, out_df: DataFrame) -> DataFrame:
def _join_in_out_frames(
in_df: DataFrame, out_df: DataFrame, allow_missing_in_df_ids: bool = False
) -> DataFrame:
"""Helper function to join our input and output frames"""
in_frame = in_df.copy()
# There is an odd vaex bug where sometimes we lose the continuity of this dataframe
@@ -36,7 +38,7 @@ def _join_in_out_frames(in_df: DataFrame, out_df: DataFrame) -> DataFrame:
in_frame["id"] = in_frame["id"].values
out_frame = out_df.copy()
in_out = out_frame.join(in_frame, on="id", how="inner", lsuffix="_L").copy()
if len(in_out) != len(out_frame):
if len(in_out) != len(out_frame) and not allow_missing_in_df_ids:
num_missing = len(out_frame) - len(in_out)
missing_ids = set(out_frame["id"].unique()) - set(in_out["id_L"].unique())
split = out_frame["split"].unique()[0]
@@ -45,6 +47,14 @@
f"for split {split}. {num_missing} corresponding input IDs are missing:\n"
f"{missing_ids}"
)
elif allow_missing_in_df_ids:
# If we're downsampling, take the intersection of the in and out ids
# and drop any out rows that have no corresponding in row
in_ids = set(in_frame["id"].unique())
out_ids = set(out_frame["id"].unique())
id_intersection = np.array(list(in_ids.intersection(out_ids)))
in_out = in_out[in_out["id_L"].isin(id_intersection)]

keep_cols = [c for c in in_out.get_column_names() if not c.endswith("_L")]
in_out = in_out[keep_cols]
return in_out
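
A rough sketch of the relaxed join semantics with toy frames (values illustrative): when allow_missing_in_df_ids=True, mismatched ids are dropped instead of raising.

    import numpy as np
    import vaex

    in_df = vaex.from_arrays(id=np.array([0, 1, 2]), gold=np.array(["a", "b", "a"]))
    out_df = vaex.from_arrays(id=np.array([0, 1, 3]), pred=np.array(["a", "b", "b"]))

    in_out = out_df.join(in_df, on="id", how="inner", lsuffix="_L")
    id_intersection = np.array(
        list(set(in_df["id"].unique()) & set(out_df["id"].unique()))
    )
    # The inner join already drops id=2 (input-only) and id=3 (output-only);
    # the isin filter mirrors the commit's extra guard against stragglers.
    in_out = in_out[in_out["id_L"].isin(id_intersection)]
    print(in_out["id_L"].tolist())  # [0, 1]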
100 changes: 100 additions & 0 deletions tests/integrations/torch/test_random_sampler.py
@@ -0,0 +1,100 @@
from typing import Callable, Generator, Tuple
from unittest.mock import MagicMock, patch

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

import dataquality as dq
from dataquality.clients.api import ApiClient
from dataquality.schemas.split import Split
from tests.conftest import DEFAULT_PROJECT_ID, DEFAULT_RUN_ID


# Assuming your labels are the target for your model
class TextDataset(Dataset):
def __init__(
self,
dataframe: pd.DataFrame,
id_column: str,
text_column: str,
label_column: str,
) -> None:
self.dataframe = dataframe
self.text = dataframe[text_column]
self.ids = dataframe[id_column]
self.labels = dataframe[label_column]
self.label_encoder = LabelEncoder()
self.labels = self.label_encoder.fit_transform(self.labels)

def __len__(self) -> int:
return len(self.dataframe)

def __getitem__(self, idx: int) -> Tuple:
text = self.text[idx]
label = self.labels[idx]
ids = self.ids[idx]
return ids, text, label


@patch.object(ApiClient, "valid_current_user", return_value=True)
@patch.object(dq.core.finish, "_version_check")
@patch.object(dq.core.finish, "_reset_run")
@patch.object(dq.core.finish, "upload_dq_log_file")
@patch.object(ApiClient, "make_request")
@patch.object(dq.core.finish, "wait_for_run")
@patch.object(ApiClient, "get_project_by_name")
@patch.object(ApiClient, "create_project")
@patch.object(ApiClient, "get_project_run_by_name", return_value={})
@patch.object(ApiClient, "create_run")
@patch("dataquality.core.init._check_dq_version")
@patch.object(
dq.clients.api.ApiClient,
"get_healthcheck_dq",
return_value={
"bucket_names": {
"images": "galileo-images",
"results": "galileo-project-runs-results",
"root": "galileo-project-runs",
},
"minio_fqdn": "127.0.0.1:9000",
},
)
@patch.object(dq.core.init.ApiClient, "valid_current_user", return_value=True)
def test_random(
mock_valid_user: MagicMock,
mock_dq_healthcheck: MagicMock,
mock_check_dq_version: MagicMock,
mock_create_run: MagicMock,
mock_get_project_run_by_name: MagicMock,
mock_create_project: MagicMock,
mock_get_project_by_name: MagicMock,
set_test_config: Callable,
mock_wait_for_run: MagicMock,
mock_make_request: MagicMock,
mock_upload_log_file: MagicMock,
mock_reset_run: MagicMock,
mock_version_check: MagicMock,
cleanup_after_use: Generator,
) -> None:
mock_get_project_by_name.return_value = {"id": DEFAULT_PROJECT_ID}
mock_create_run.return_value = {"id": DEFAULT_RUN_ID}
set_test_config(current_project_id=None, current_run_id=None)
dq.init(task_type="image_classification")
labels = ["a", "b"]
dq.set_labels_for_run(labels)

dq.log_data_samples(
texts=["a", "b", "a"],
labels=["a", "b", "a"],
ids=[0, 1, 2],
split=Split.training,
)
dq.log_model_outputs(
embs=[[0, 0], [1, 1]],
ids=[0, 1],
probs=[[0, 1], [1, 0]],
split=Split.training,
epoch=0,
)
dq.finish()
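
The new test can be run in isolation with the repo's usual runner, e.g. pytest tests/integrations/torch/test_random_sampler.py (assuming the project's standard dev setup and conftest fixtures).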
