Skip to content

Commit

Permalink
fix: include metadata cols in get dataframe (#787)
Browse files Browse the repository at this point in the history
  • Loading branch information
Elliott authored Nov 1, 2023
1 parent 5d9194f commit 1828e8a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


__version__ = "1.1.10"
__version__ = "1.1.11"

import sys
from typing import Any, List, Optional
Expand Down
5 changes: 2 additions & 3 deletions dataquality/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,15 +433,14 @@ def export_run(
:param slice_name: The optional slice name to export. If selected, this data
from this slice will be exported only.
:param include_cols: List of columns to include in the export. If not set,
all columns will be exported.
all columns will be exported. If "*" is included, all metadata columns are returned.
:param col_mapping: Dictionary of renamed column names for export.
:param hf_format: (NER only)
Whether to export the dataframe in a HuggingFace compatible format
:param tagging_schema: (NER only)
If hf_format is True, you must pass a tagging schema
:param filter_params: Filters to apply to the dataframe before exporting. Only
rows with matching filters will be included in the exported data. If a slice
rows with matching filters will be included in the exported data.
"""
project, run = self._get_project_run_id(project_name, run_name)
ext = os.path.splitext(file_name)[-1].lstrip(".")
Expand Down
6 changes: 6 additions & 0 deletions dataquality/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def _download_df(
hf_format: bool,
tagging_schema: Optional[TaggingSchema],
filter_params: FilterParams,
meta_cols: List[str],
) -> DataFrame:
"""Helper function to download the dataframe to take advantage of caching
Expand All @@ -249,6 +250,7 @@ def _download_df(
split,
inference_name=inference_name,
file_name=file_name,
include_cols=meta_cols,
filter_params=filter_params.dict(),
hf_format=hf_format,
tagging_schema=tagging_schema,
Expand All @@ -271,6 +273,7 @@ def get_dataframe(
filter: Optional[Union[FilterParams, Dict]] = None,
as_pandas: bool = True,
include_data_embs: bool = False,
meta_cols: Optional[List[str]] = None,
) -> Union[pd.DataFrame, DataFrame]:
"""Gets the dataframe for a run/split
Expand Down Expand Up @@ -305,6 +308,8 @@ def get_dataframe(
(embeddings, probabilities etc), vaex will always be returned, because pandas
cannot support multi-dimensional columns. Default True
:param include_data_embs: Whether to include the off-the-shelf data embeddings
:param meta_cols: List of metadata columns to return in the dataframe. If "*"
is included, all metadata columns are returned.
"""
split = conform_split(split)
project_id, run_id = api_client._get_project_run_id(project_name, run_name)
Expand All @@ -320,6 +325,7 @@ def get_dataframe(
hf_format,
tagging_schema,
filter_params,
meta_cols or [],
)
return _process_exported_dataframe(
data_df,
Expand Down

0 comments on commit 1828e8a

Please sign in to comment.