From 5539563d13b3ff1dacbf2c2b8be7fabbe4163cc3 Mon Sep 17 00:00:00 2001
From: Dave Shoup
Date: Wed, 7 Sep 2022 13:53:48 -0400
Subject: [PATCH] fix column cleaning and general cleanup (#39)

* remove callout placeholders
* remove np.nan fill; remove duplicate cleaning functions
* pass remaining cleaning function to tests
* remove column debug logging
* remove callout prep during sampling
* ensure dtypes are consistent before/after sampling
* handle fillna in payload generation
* add tests to make sure null values don't cause problems internally
* update test docstring
* debug logging for dtype conversions
* remove is_equal and use df.equals(other_df)
* skip cleaning for float/int/bool columns/datetime columns
* debug logging for column conversions
---
 dx/filtering.py               |  1 -
 dx/formatters/dataresource.py | 20 ++++++------
 dx/formatters/dx.py           | 23 +++++++-------
 dx/sampling.py                | 58 ++++++++++-------------------------
 dx/tests/test_datatypes.py    |  5 +++
 dx/tests/test_formatting.py   | 36 ++++++++++++++++++++++
 dx/utils/formatting.py        | 46 ++++++---------------------
 dx/utils/tracking.py          | 45 +++++++++------------------
 8 files changed, 103 insertions(+), 131 deletions(-)

diff --git a/dx/filtering.py b/dx/filtering.py
index d8fcd660..12f81921 100644
--- a/dx/filtering.py
+++ b/dx/filtering.py
@@ -80,7 +80,6 @@ def update_display_id(
     query_string = sql_filter.format(table_name=table_name)
     logger.debug(f"sql query string: {query_string}")
     new_df = pd.read_sql(query_string, sql_engine)
-    logger.debug(f"{new_df.columns=}")

     with sql_engine.connect() as conn:
         orig_df_count = conn.execute(f"SELECT COUNT (*) FROM {table_name}").scalar()
diff --git a/dx/formatters/dataresource.py b/dx/formatters/dataresource.py
index 44a4b91a..3b0b7162 100644
--- a/dx/formatters/dataresource.py
+++ b/dx/formatters/dataresource.py
@@ -2,12 +2,13 @@
 from functools import lru_cache
 from typing import Optional

+import numpy as np
 import pandas as pd
 import structlog
 from IPython import get_ipython
 from IPython.core.formatters import DisplayFormatter
 from IPython.core.interactiveshell import InteractiveShell
-from IPython.display import HTML, display
+from IPython.display import display
 from pandas.io.json import build_table_schema
 from pydantic import BaseSettings, Field

@@ -64,6 +65,7 @@ def handle_dataresource_format(
     logger.debug(f"*** handling dataresource format for {type(obj)=} ***")
     if not isinstance(obj, pd.DataFrame):
         obj = to_dataframe(obj)
+        logger.debug(f"{obj.shape=}")

     default_index_used = is_default_index(obj.index)

@@ -138,9 +140,15 @@ def generate_dataresource_body(
     Transforms the dataframe to a payload dictionary containing the
     table schema and column values as arrays.
     """
+    schema = build_table_schema(df)
+    logger.debug(f"{schema=}")
+
+    # fillna(np.nan) to handle pd.NA values
+    data = df.fillna(np.nan).reset_index().to_dict("records")
+
     payload = {
-        "schema": build_table_schema(df),
-        "data": df.reset_index().to_dict("records"),
+        "schema": schema,
+        "data": data,
         "datalink": {"display_id": display_id},
     }
     return payload
@@ -186,12 +194,6 @@ def format_dataresource(
         display_id=display_id,
         update=update,
     )
-    # temporary placeholder for copy/paste user messaging
-    display(
-        HTML(""),
"), - display_id=display_id + "-primary", - update=update, - ) return (payload, metadata) diff --git a/dx/formatters/dx.py b/dx/formatters/dx.py index b859531c..710d5d61 100644 --- a/dx/formatters/dx.py +++ b/dx/formatters/dx.py @@ -2,12 +2,13 @@ from functools import lru_cache from typing import Optional +import numpy as np import pandas as pd import structlog from IPython import get_ipython from IPython.core.formatters import DisplayFormatter from IPython.core.interactiveshell import InteractiveShell -from IPython.display import HTML, display +from IPython.display import display from pandas.io.json import build_table_schema from pydantic import BaseSettings, Field @@ -60,8 +61,10 @@ def handle_dx_format( ): ipython = ipython_shell or get_ipython() + logger.debug(f"*** handling DEX format for {type(obj)=} ***") if not isinstance(obj, pd.DataFrame): obj = to_dataframe(obj) + logger.debug(f"{obj.shape=}") default_index_used = is_default_index(obj.index) @@ -136,10 +139,16 @@ def generate_dx_body( Transforms the dataframe to a payload dictionary containing the table schema and column values as arrays. """ + schema = build_table_schema(df) + logger.debug(f"{schema=}") + + # fillna(np.nan) to handle pd.NA values + data = df.fillna(np.nan).reset_index().transpose().values.tolist() + # this will include the `df.index` by default (e.g. slicing/sampling) payload = { - "schema": build_table_schema(df), - "data": df.reset_index().transpose().values.tolist(), + "schema": schema, + "data": data, "datalink": {"display_id": display_id}, } return payload @@ -184,14 +193,6 @@ def format_dx( update=update, ) - # temporary placeholder for copy/paste user messaging - if settings.ENABLE_DATALINK: - display( - HTML("
"), - display_id=display_id + "-primary", - update=update, - ) - return (payload, metadata) diff --git a/dx/sampling.py b/dx/sampling.py index 7a6c0a96..b0912ae3 100644 --- a/dx/sampling.py +++ b/dx/sampling.py @@ -6,7 +6,6 @@ from dx.settings import settings from dx.types import DXSamplingMethod -from dx.utils.formatting import human_readable_size logger = structlog.get_logger(__name__) @@ -17,66 +16,41 @@ def sample_if_too_big(df: pd.DataFrame, display_id: Optional[str] = None) -> pd. to help reduce the amount of data being sent to the frontend for non-default media types. """ - - warnings = [] + orig_dtypes = set(df.dtypes.to_dict().items()) # check number of columns first, then trim rows if needed max_columns = settings.DISPLAY_MAX_COLUMNS df_too_wide = len(df.columns) > max_columns if df_too_wide: - num_orig_columns = len(df.columns) df = sample_columns(df, num_cols=max_columns) - col_warning = f"""Dataframe has {num_orig_columns:,} column(s), - which is more than {settings.DISPLAY_MAX_COLUMNS=}""" - warnings.append(col_warning) # check number of rows next, then start reducing even more max_rows = settings.DISPLAY_MAX_ROWS df_too_long = len(df) > max_rows if df_too_long: - num_orig_rows = len(df) df = sample_rows(df, num_rows=max_rows, display_id=display_id) - row_warning = f"""Dataframe has {num_orig_rows:,} row(s), - which is more than {settings.DISPLAY_MAX_ROWS=}""" - warnings.append(row_warning) # in the event that there are nested/large values bloating the dataframe, # easiest to reduce rows even further here max_size_bytes = settings.MAX_RENDER_SIZE_BYTES df_too_big = sys.getsizeof(df) > max_size_bytes if df_too_big: - orig_size = sys.getsizeof(df) df = reduce_df(df) - size_str = human_readable_size(orig_size) - max_size_str = human_readable_size(max_size_bytes) - settings_size_str = f"{settings.MAX_RENDER_SIZE_BYTES=} ({max_size_str})" - size_warning = f"""Dataframe is {size_str}, which is more than {settings_size_str}""" - warnings.append(size_warning) - - # TODO: remove this altogether once frontend uses new metadata to create warning - if warnings: - warning_html = "
".join(warnings) - new_size_html = f"""A truncated version with {len(df):,} row(s) and - {len(df.reset_index().columns):,} column(s) will be viewable in DEX.""" - warning_html = f"{warning_html}
{new_size_html}" - - # give users more information on how to change settings - override_snippet = ( - """dx.set_option({setting name}, {new value})""" - ) - sample_override = """dx.set_option("DISPLAY_MAX_ROWS", 250_000)""" - override_warning = "*Be careful; increasing these limits may negatively impact performance." - user_feedback = f"""
-

To adjust the settings*, execute {override_snippet} in a new cell. -
For example, to change the maximum number of rows to display to 250,000, - you could execute the following: {sample_override}

- {override_warning}
""" - user_feedback_collapsed_section = ( - f"""
More Information{user_feedback}
""" - ) - - warning_html = f"{warning_html} {user_feedback_collapsed_section}" - # display_callout(warning_html, level="warning") + + # sampling may convert columns to `object` dtype, so we need to make sure + # the original dtypes persist before generating the body for the frontend + current_dtypes = set(df.dtypes.to_dict()) + dtype_conversions = orig_dtypes - current_dtypes + if dtype_conversions: + for column, dtype in dtype_conversions: + if column not in df.columns: + # this is a column that was dropped during sampling + logger.debug(f"`{column}` no longer in df, skipping dtype conversion") + continue + if str(df[column].dtype) == str(dtype): + continue + logger.debug(f"converting `{column}` from `{df[column].dtype!r}` to `{dtype!r}`") + df[column] = df[column].astype(dtype) return df diff --git a/dx/tests/test_datatypes.py b/dx/tests/test_datatypes.py index 1ed09339..c9e48652 100644 --- a/dx/tests/test_datatypes.py +++ b/dx/tests/test_datatypes.py @@ -17,6 +17,7 @@ quick_random_dataframe, random_dataframe, ) +from dx.utils.formatting import clean_column_values from dx.utils.tracking import generate_df_hash, sql_engine, store_in_sqlite @@ -120,6 +121,8 @@ def test_generate_df_hash(dtype: str): params = {dt: False for dt in SORTED_DX_DATATYPES} params[dtype] = True df = random_dataframe(**params) + for col in df.columns: + df[col] = clean_column_values(df[col]) try: hash_str = generate_df_hash(df) except Exception as e: @@ -153,6 +156,8 @@ def test_store_in_sqlite(dtype: str): params = {dt: False for dt in SORTED_DX_DATATYPES} params[dtype] = True df = random_dataframe(**params) + for col in df.columns: + df[col] = clean_column_values(df[col]) try: num_rows = store_in_sqlite(f"{dtype}_test", df) except Exception as e: diff --git a/dx/tests/test_formatting.py b/dx/tests/test_formatting.py index 4a6e5d68..7c84244e 100644 --- a/dx/tests/test_formatting.py +++ b/dx/tests/test_formatting.py @@ -1,4 +1,6 @@ +import numpy as np import pandas as pd +import pytest from IPython.terminal.interactiveshell import TerminalInteractiveShell from dx.formatters.dataresource import get_dataresource_settings, handle_dataresource_format @@ -62,3 +64,37 @@ def test_dx_nonunique_index_succeeds( handle_dx_format(double_df, ipython_shell=get_ipython) except Exception as e: assert False, f"{e}" + + +@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) +def test_dataresource_succeeds_with_missing_data( + sample_dataframe: pd.DataFrame, + get_ipython: TerminalInteractiveShell, + null_value, +): + """ + Test dataresource formatting doesn't fail while formatting + a dataframe with null values. + """ + sample_dataframe["missing_data"] = null_value + try: + handle_dataresource_format(sample_dataframe, ipython_shell=get_ipython) + except Exception as e: + assert False, f"{e}" + + +@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA]) +def test_dx_succeeds_with_missing_data( + sample_dataframe: pd.DataFrame, + get_ipython: TerminalInteractiveShell, + null_value, +): + """ + Test dx formatting doesn't fail while formatting + a dataframe with null values. 
+ """ + sample_dataframe["missing_data"] = null_value + try: + handle_dx_format(sample_dataframe, ipython_shell=get_ipython) + except Exception as e: + assert False, f"{e}" diff --git a/dx/utils/formatting.py b/dx/utils/formatting.py index b6dba46f..34ee970e 100644 --- a/dx/utils/formatting.py +++ b/dx/utils/formatting.py @@ -1,4 +1,3 @@ -import numpy as np import pandas as pd import structlog @@ -54,9 +53,6 @@ def normalize_index_and_columns(df: pd.DataFrame) -> pd.DataFrame: display_df = normalize_index(display_df) display_df = normalize_columns(display_df) - # build_table_schema() doesn't like pd.NAs - display_df.fillna(np.nan, inplace=True) - return display_df @@ -107,10 +103,14 @@ def normalize_columns(df: pd.DataFrame) -> pd.DataFrame: if settings.STRINGIFY_COLUMN_VALUES: df.columns = pd.Index(stringify_index(df.columns)) - logger.debug("-- cleaning before display --") + logger.debug("-- cleaning columns before display --") for column in df.columns: - df[column] = clean_column_values_for_display(df[column]) - + standard_dtypes = ["float", "int", "bool"] + if df[column].dtype in standard_dtypes or str(df[column].dtype).startswith("datetime"): + logger.debug(f"skipping `{column=}` since it has dtype `{df[column].dtype}`") + continue + logger.debug(f"--> cleaning `{column=}` with dtype `{df[column].dtype}`") + df[column] = clean_column_values(df[column]) return df @@ -125,7 +125,7 @@ def stringify_index(index: pd.Index): return tuple(map(str, index)) -def clean_column_values_for_display(s: pd.Series) -> pd.Series: +def clean_column_values(s: pd.Series) -> pd.Series: """ Cleaning/conversion for values in a series to prevent build_table_schema() or frontend rendering errors. @@ -138,39 +138,11 @@ def clean_column_values_for_display(s: pd.Series) -> pd.Series: s = datatypes.handle_ip_address_series(s) s = datatypes.handle_complex_number_series(s) - s = geometry.handle_geometry_series(s) - s = datatypes.handle_unk_type_series(s) - return s - - -def clean_column_values_for_hash(s: pd.Series) -> pd.Series: - """ - Cleaning/conversion for values in a series to prevent - hash_pandas_object() errors. - """ - s = geometry.handle_geometry_series(s) - - s = datatypes.handle_dict_series(s) - s = datatypes.handle_sequence_series(s) - return s - - -def clean_column_values_for_sqlite(s: pd.Series) -> pd.Series: - """ - Cleaning/conversion for values in a series to prevent - errors writing to sqlite. 
- """ - s = datatypes.handle_dtype_series(s) - s = datatypes.handle_interval_series(s) - s = datatypes.handle_complex_number_series(s) - s = datatypes.handle_ip_address_series(s) - - s = date_time.handle_time_period_series(s) - s = geometry.handle_geometry_series(s) s = datatypes.handle_dict_series(s) s = datatypes.handle_sequence_series(s) + s = datatypes.handle_unk_type_series(s) return s diff --git a/dx/utils/tracking.py b/dx/utils/tracking.py index b9e0aa9a..8c9e87ce 100644 --- a/dx/utils/tracking.py +++ b/dx/utils/tracking.py @@ -11,11 +11,6 @@ from dx.utils.datatypes import has_numeric_strings, is_sequence_series from dx.utils.date_time import is_datetime_series -from dx.utils.formatting import ( - clean_column_values_for_hash, - clean_column_values_for_sqlite, - normalize_index_and_columns, -) logger = structlog.get_logger(__name__) sql_engine = create_engine("sqlite://", echo=False) @@ -84,9 +79,6 @@ def generate_df_hash(df: pd.DataFrame) -> str: """ hash_df = df.copy() - for col in hash_df.columns: - hash_df[col] = clean_column_values_for_hash(hash_df[col]) - # this will be a series of hash values the length of df df_hash_series = hash_pandas_object(hash_df) # then string-concatenate all the hashed values, which could be very large @@ -96,22 +88,6 @@ def generate_df_hash(df: pd.DataFrame) -> str: return hash_str -def is_equal(df: pd.DataFrame, other_df: pd.DataFrame, df_hash: str): - if df.shape != other_df.shape: - return False - if sorted(list(df.columns)) != sorted(list(other_df.columns)): - return False - - # this could be expensive, so we only want to do it if we're - # pretty sure two dataframes could be equal - logger.debug("-- cleaning before hashing --") - other_hash = generate_df_hash(normalize_index_and_columns(other_df)) - if df_hash != other_hash: - return False - - return True - - def get_df_variable_name( df: pd.DataFrame, ipython_shell: Optional[InteractiveShell] = None, @@ -120,15 +96,28 @@ def get_df_variable_name( """ Returns the variable name of the DataFrame object. 
""" + logger.debug("looking for matching variables for dataframe") + ipython = ipython_shell or get_ipython() df_vars = {k: v for k, v in ipython.user_ns.items() if isinstance(v, pd.DataFrame)} logger.debug(f"dataframe variables present: {list(df_vars.keys())}") df_hash = df_hash or generate_df_hash(df) - matching_df_vars = [k for k, v in df_vars.items() if is_equal(df, v, df_hash)] + matching_df_vars = [] + for k, v in df_vars.items(): + logger.debug(f"checking if `{k}` is equal to this dataframe") + # we previously checked columns, dtypes, shape, etc between both dataframes, + # to avoid having to normalize and hash the other dataframe (v here), + # but that was too slow, and ultimately we shouldn't be checking raw data vs cleaned data + # so .equals() should be the most performant + if df.equals(v): + logger.debug(f"`{k}` matches this dataframe") + matching_df_vars.append(k) + logger.debug(f"dataframe variables with same hash: {matching_df_vars}") # we might get a mix of references here like ['_', '__', 'df'] named_df_vars_with_same_hash = [name for name in matching_df_vars if not name.startswith("_")] + logger.debug(f"named dataframe variables with same hash: {named_df_vars_with_same_hash}") if named_df_vars_with_same_hash: logger.debug(f"{named_df_vars_with_same_hash=}") return named_df_vars_with_same_hash[0] @@ -184,10 +173,6 @@ def store_in_sqlite(table_name: str, df: pd.DataFrame): logger.debug(f"{df.columns=}") tracking_df = df.copy() - logger.debug("-- cleaning before sqlite --") - for col in tracking_df.columns: - tracking_df[col] = clean_column_values_for_sqlite(tracking_df[col]) - logger.debug(f"writing to `{table_name}` table in sqlite") with sql_engine.begin() as conn: num_written_rows = tracking_df.to_sql( @@ -213,8 +198,6 @@ def track_column_conversions( # to the cleaned version of the dataframe, pull the index values # of the resulting row(s), then swap out the results with the # index positions of the original data - logger.debug(f"{orig_df.columns=}") - logger.debug(f"{df.columns=}") DISPLAY_ID_TO_INDEX[display_id] = df.index.name DISPLAY_ID_TO_DATETIME_COLUMNS[display_id] = [