Commit

fix column cleaning and general cleanup (#39)

* remove callout placeholders

* remove np.nan fill; remove duplicate cleaning functions

* pass remaining cleaning function to tests

* remove column debug logging

* remove callout prep during sampling

* ensure dtypes are consistent before/after sampling

* handle fillna in payload generation

* add tests to make sure null values don't cause problems internally

* update test docstring

* debug logging for dtype conversions

* remove is_equal and use df.equals(other_df)

* skip cleaning for float/int/bool/datetime columns

* debug logging for column conversions
shouples authored Sep 7, 2022
1 parent 409488e commit 5539563
Showing 8 changed files with 103 additions and 131 deletions.
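
One of the commit bullets above swaps a hand-rolled is_equal helper for pandas' built-in DataFrame.equals. As a quick illustration (not code from this commit), equals already treats missing values in matching positions as equal, which is what a custom helper would otherwise have to handle:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"a": [1.0, np.nan], "b": ["x", None]})
    other = pd.DataFrame({"a": [1.0, np.nan], "b": ["x", None]})

    # NaNs/Nones in the same locations compare as equal,
    # so no custom equality helper is needed
    assert df.equals(other)
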
1 change: 0 additions & 1 deletion dx/filtering.py
@@ -80,7 +80,6 @@ def update_display_id(
query_string = sql_filter.format(table_name=table_name)
logger.debug(f"sql query string: {query_string}")
new_df = pd.read_sql(query_string, sql_engine)
logger.debug(f"{new_df.columns=}")

with sql_engine.connect() as conn:
orig_df_count = conn.execute(f"SELECT COUNT (*) FROM {table_name}").scalar()
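
For context, update_display_id() filters by running SQL against a local table and comparing row counts. A minimal sketch of that round-trip (not part of the diff), assuming an in-memory SQLite engine, a hypothetical table name, and the same legacy SQLAlchemy raw-string execution style seen above:

    import pandas as pd
    import sqlalchemy

    sql_engine = sqlalchemy.create_engine("sqlite://")
    pd.DataFrame({"a": [1, 2, 3]}).to_sql("demo_table", sql_engine, index=False)

    # filter the table the same way the surrounding code does
    new_df = pd.read_sql("SELECT * FROM demo_table WHERE a > 1", sql_engine)
    with sql_engine.connect() as conn:
        orig_df_count = conn.execute("SELECT COUNT (*) FROM demo_table").scalar()
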
20 changes: 11 additions & 9 deletions dx/formatters/dataresource.py
@@ -2,12 +2,13 @@
from functools import lru_cache
from typing import Optional

import numpy as np
import pandas as pd
import structlog
from IPython import get_ipython
from IPython.core.formatters import DisplayFormatter
from IPython.core.interactiveshell import InteractiveShell
from IPython.display import HTML, display
from IPython.display import display
from pandas.io.json import build_table_schema
from pydantic import BaseSettings, Field

@@ -64,6 +65,7 @@ def handle_dataresource_format(
logger.debug(f"*** handling dataresource format for {type(obj)=} ***")
if not isinstance(obj, pd.DataFrame):
obj = to_dataframe(obj)
logger.debug(f"{obj.shape=}")

default_index_used = is_default_index(obj.index)

@@ -138,9 +140,15 @@ def generate_dataresource_body(
Transforms the dataframe to a payload dictionary containing the
table schema and column values as arrays.
"""
schema = build_table_schema(df)
logger.debug(f"{schema=}")

# fillna(np.nan) to handle pd.NA values
data = df.fillna(np.nan).reset_index().to_dict("records")

payload = {
"schema": build_table_schema(df),
"data": df.reset_index().to_dict("records"),
"schema": schema,
"data": data,
"datalink": {"display_id": display_id},
}
return payload
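
The fillna call added here exists because build_table_schema() and downstream serialization don't handle pd.NA values well (the blanket fillna removed from dx/utils/formatting.py below cited the same reason), so pd.NA is now coerced to np.nan only at payload-generation time. A toy sketch of the pattern, separate from the diff:

    import numpy as np
    import pandas as pd
    from pandas.io.json import build_table_schema

    df = pd.DataFrame({"a": [1, pd.NA, 3]}, dtype="object")

    schema = build_table_schema(df)
    # coerce pd.NA to np.nan before serializing row records
    data = df.fillna(np.nan).reset_index().to_dict("records")
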
@@ -186,12 +194,6 @@ def format_dataresource(
display_id=display_id,
update=update,
)
# temporary placeholder for copy/paste user messaging
display(
HTML("<div></div>"),
display_id=display_id + "-primary",
update=update,
)

return (payload, metadata)

23 changes: 12 additions & 11 deletions dx/formatters/dx.py
@@ -2,12 +2,13 @@
from functools import lru_cache
from typing import Optional

import numpy as np
import pandas as pd
import structlog
from IPython import get_ipython
from IPython.core.formatters import DisplayFormatter
from IPython.core.interactiveshell import InteractiveShell
from IPython.display import HTML, display
from IPython.display import display
from pandas.io.json import build_table_schema
from pydantic import BaseSettings, Field

@@ -60,8 +61,10 @@ def handle_dx_format(
):
ipython = ipython_shell or get_ipython()

logger.debug(f"*** handling DEX format for {type(obj)=} ***")
if not isinstance(obj, pd.DataFrame):
obj = to_dataframe(obj)
logger.debug(f"{obj.shape=}")

default_index_used = is_default_index(obj.index)

@@ -136,10 +139,16 @@ def generate_dx_body(
Transforms the dataframe to a payload dictionary containing the
table schema and column values as arrays.
"""
schema = build_table_schema(df)
logger.debug(f"{schema=}")

# fillna(np.nan) to handle pd.NA values
data = df.fillna(np.nan).reset_index().transpose().values.tolist()

# this will include the `df.index` by default (e.g. slicing/sampling)
payload = {
"schema": build_table_schema(df),
"data": df.reset_index().transpose().values.tolist(),
"schema": schema,
"data": data,
"datalink": {"display_id": display_id},
}
return payload
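
Note the payload shape differs from the dataresource formatter: dataresource sends row-wise records, while the DEX body sends one array per column. A small illustration (not from the diff):

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})

    # dataresource: row-wise records, with the index added by reset_index()
    df.reset_index().to_dict("records")
    # -> [{"index": 0, "a": 1, "b": "x"}, {"index": 1, "a": 2, "b": "y"}]

    # dx: column-wise arrays, same index handling
    df.reset_index().transpose().values.tolist()
    # -> [[0, 1], [1, 2], ["x", "y"]]
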
@@ -184,14 +193,6 @@ def format_dx(
update=update,
)

# temporary placeholder for copy/paste user messaging
if settings.ENABLE_DATALINK:
display(
HTML("<div></div>"),
display_id=display_id + "-primary",
update=update,
)

return (payload, metadata)


58 changes: 16 additions & 42 deletions dx/sampling.py
@@ -6,7 +6,6 @@

from dx.settings import settings
from dx.types import DXSamplingMethod
from dx.utils.formatting import human_readable_size

logger = structlog.get_logger(__name__)

@@ -17,66 +16,41 @@ def sample_if_too_big(df: pd.DataFrame, display_id: Optional[str] = None) -> pd.
to help reduce the amount of data being sent to the
frontend for non-default media types.
"""

warnings = []
orig_dtypes = set(df.dtypes.to_dict().items())

# check number of columns first, then trim rows if needed
max_columns = settings.DISPLAY_MAX_COLUMNS
df_too_wide = len(df.columns) > max_columns
if df_too_wide:
num_orig_columns = len(df.columns)
df = sample_columns(df, num_cols=max_columns)
col_warning = f"""Dataframe has {num_orig_columns:,} column(s),
which is more than <code>{settings.DISPLAY_MAX_COLUMNS=}</code>"""
warnings.append(col_warning)

# check number of rows next, then start reducing even more
max_rows = settings.DISPLAY_MAX_ROWS
df_too_long = len(df) > max_rows
if df_too_long:
num_orig_rows = len(df)
df = sample_rows(df, num_rows=max_rows, display_id=display_id)
row_warning = f"""Dataframe has {num_orig_rows:,} row(s),
which is more than <code>{settings.DISPLAY_MAX_ROWS=}</code>"""
warnings.append(row_warning)

# in the event that there are nested/large values bloating the dataframe,
# easiest to reduce rows even further here
max_size_bytes = settings.MAX_RENDER_SIZE_BYTES
df_too_big = sys.getsizeof(df) > max_size_bytes
if df_too_big:
orig_size = sys.getsizeof(df)
df = reduce_df(df)
size_str = human_readable_size(orig_size)
max_size_str = human_readable_size(max_size_bytes)
settings_size_str = f"<code>{settings.MAX_RENDER_SIZE_BYTES=}</code> ({max_size_str})"
size_warning = f"""Dataframe is {size_str}, which is more than {settings_size_str}"""
warnings.append(size_warning)

# TODO: remove this altogether once frontend uses new metadata to create warning
if warnings:
warning_html = "<br/>".join(warnings)
new_size_html = f"""A truncated version with <strong>{len(df):,}</code> row(s) and
{len(df.reset_index().columns):,} column(s)</strong> will be viewable in DEX."""
warning_html = f"{warning_html}<br/>{new_size_html}"

# give users more information on how to change settings
override_snippet = (
"""<mark><code>dx.set_option({setting name}, {new value})</code></mark>"""
)
sample_override = """<code>dx.set_option("DISPLAY_MAX_ROWS", 250_000)</code>"""
override_warning = "<small><i><sup>*</sup>Be careful; increasing these limits may negatively impact performance.</i></small>"
user_feedback = f"""<div style="padding:0.25rem 1rem;">
<p>To adjust the settings*, execute {override_snippet} in a new cell.
<br/>For example, to change the maximum number of rows to display to 250,000,
you could execute the following: {sample_override}</p>
{override_warning}</div>"""
user_feedback_collapsed_section = (
f"""<details><summary>More Information</summary>{user_feedback}</details>"""
)

warning_html = f"{warning_html} {user_feedback_collapsed_section}"
# display_callout(warning_html, level="warning")

# sampling may convert columns to `object` dtype, so we need to make sure
# the original dtypes persist before generating the body for the frontend
current_dtypes = set(df.dtypes.to_dict())
dtype_conversions = orig_dtypes - current_dtypes
if dtype_conversions:
for column, dtype in dtype_conversions:
if column not in df.columns:
# this is a column that was dropped during sampling
logger.debug(f"`{column}` no longer in df, skipping dtype conversion")
continue
if str(df[column].dtype) == str(dtype):
continue
logger.debug(f"converting `{column}` from `{df[column].dtype!r}` to `{dtype!r}`")
df[column] = df[column].astype(dtype)

return df
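
The dtype-restoration loop above guards against sampling operations coercing columns to object. A standalone sketch of the same idea, where the astype("object") step is a stand-in for whatever coercion sampling might cause:

    import pandas as pd

    df = pd.DataFrame(
        {"when": pd.date_range("2022-01-01", periods=4), "n": [1, 2, 3, 4]}
    )
    orig_dtypes = df.dtypes.to_dict()

    # stand-in for a sampling step that coerces columns to object
    sampled = df.sample(2).astype("object")

    # restore each surviving column to its original dtype
    for column, dtype in orig_dtypes.items():
        if column in sampled.columns and str(sampled[column].dtype) != str(dtype):
            sampled[column] = sampled[column].astype(dtype)
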

5 changes: 5 additions & 0 deletions dx/tests/test_datatypes.py
@@ -17,6 +17,7 @@
quick_random_dataframe,
random_dataframe,
)
from dx.utils.formatting import clean_column_values
from dx.utils.tracking import generate_df_hash, sql_engine, store_in_sqlite


@@ -120,6 +121,8 @@ def test_generate_df_hash(dtype: str):
params = {dt: False for dt in SORTED_DX_DATATYPES}
params[dtype] = True
df = random_dataframe(**params)
for col in df.columns:
df[col] = clean_column_values(df[col])
try:
hash_str = generate_df_hash(df)
except Exception as e:
@@ -153,6 +156,8 @@ def test_store_in_sqlite(dtype: str):
params = {dt: False for dt in SORTED_DX_DATATYPES}
params[dtype] = True
df = random_dataframe(**params)
for col in df.columns:
df[col] = clean_column_values(df[col])
try:
num_rows = store_in_sqlite(f"{dtype}_test", df)
except Exception as e:
36 changes: 36 additions & 0 deletions dx/tests/test_formatting.py
@@ -1,4 +1,6 @@
import numpy as np
import pandas as pd
import pytest
from IPython.terminal.interactiveshell import TerminalInteractiveShell

from dx.formatters.dataresource import get_dataresource_settings, handle_dataresource_format
@@ -62,3 +64,37 @@ def test_dx_nonunique_index_succeeds(
handle_dx_format(double_df, ipython_shell=get_ipython)
except Exception as e:
assert False, f"{e}"


@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
def test_dataresource_succeeds_with_missing_data(
sample_dataframe: pd.DataFrame,
get_ipython: TerminalInteractiveShell,
null_value,
):
"""
Test dataresource formatting doesn't fail while formatting
a dataframe with null values.
"""
sample_dataframe["missing_data"] = null_value
try:
handle_dataresource_format(sample_dataframe, ipython_shell=get_ipython)
except Exception as e:
assert False, f"{e}"


@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
def test_dx_succeeds_with_missing_data(
sample_dataframe: pd.DataFrame,
get_ipython: TerminalInteractiveShell,
null_value,
):
"""
Test dx formatting doesn't fail while formatting
a dataframe with null values.
"""
sample_dataframe["missing_data"] = null_value
try:
handle_dx_format(sample_dataframe, ipython_shell=get_ipython)
except Exception as e:
assert False, f"{e}"
46 changes: 9 additions & 37 deletions dx/utils/formatting.py
@@ -1,4 +1,3 @@
import numpy as np
import pandas as pd
import structlog

@@ -54,9 +53,6 @@ def normalize_index_and_columns(df: pd.DataFrame) -> pd.DataFrame:
display_df = normalize_index(display_df)
display_df = normalize_columns(display_df)

# build_table_schema() doesn't like pd.NAs
display_df.fillna(np.nan, inplace=True)

return display_df


@@ -107,10 +103,14 @@ def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
if settings.STRINGIFY_COLUMN_VALUES:
df.columns = pd.Index(stringify_index(df.columns))

logger.debug("-- cleaning before display --")
logger.debug("-- cleaning columns before display --")
for column in df.columns:
df[column] = clean_column_values_for_display(df[column])

standard_dtypes = ["float", "int", "bool"]
if df[column].dtype in standard_dtypes or str(df[column].dtype).startswith("datetime"):
logger.debug(f"skipping `{column=}` since it has dtype `{df[column].dtype}`")
continue
logger.debug(f"--> cleaning `{column=}` with dtype `{df[column].dtype}`")
df[column] = clean_column_values(df[column])
return df


@@ -125,7 +125,7 @@ def stringify_index(index: pd.Index):
return tuple(map(str, index))


def clean_column_values_for_display(s: pd.Series) -> pd.Series:
def clean_column_values(s: pd.Series) -> pd.Series:
"""
Cleaning/conversion for values in a series to prevent
build_table_schema() or frontend rendering errors.
@@ -138,39 +138,11 @@ def clean_column_values(s: pd.Series) -> pd.Series:
s = datatypes.handle_ip_address_series(s)
s = datatypes.handle_complex_number_series(s)

s = geometry.handle_geometry_series(s)
s = datatypes.handle_unk_type_series(s)
return s
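
With the separate display/hash/sqlite cleaners below collapsed into this single function, callers, including the updated tests above, now clean columns the same way everywhere. A usage sketch, assuming the dx package is installed:

    import pandas as pd
    from dx.utils.formatting import clean_column_values

    df = pd.DataFrame({"payload": [{"a": 1}, {"b": 2}]})

    # dict values could otherwise break build_table_schema(),
    # hashing, or sqlite writes, so convert them up front
    for col in df.columns:
        df[col] = clean_column_values(df[col])
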


def clean_column_values_for_hash(s: pd.Series) -> pd.Series:
"""
Cleaning/conversion for values in a series to prevent
hash_pandas_object() errors.
"""
s = geometry.handle_geometry_series(s)

s = datatypes.handle_dict_series(s)
s = datatypes.handle_sequence_series(s)
return s


def clean_column_values_for_sqlite(s: pd.Series) -> pd.Series:
"""
Cleaning/conversion for values in a series to prevent
errors writing to sqlite.
"""
s = datatypes.handle_dtype_series(s)
s = datatypes.handle_interval_series(s)
s = datatypes.handle_complex_number_series(s)
s = datatypes.handle_ip_address_series(s)

s = date_time.handle_time_period_series(s)

s = geometry.handle_geometry_series(s)

s = datatypes.handle_dict_series(s)
s = datatypes.handle_sequence_series(s)
s = datatypes.handle_unk_type_series(s)
return s

