Skip to content

Commit

Permalink
Use pyarrow fallback to detect dataframe schema
Browse files Browse the repository at this point in the history
  • Loading branch information
plamut committed Sep 26, 2019
1 parent 7c9c0cb commit 8ab716f
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 9 deletions.
55 changes: 51 additions & 4 deletions bigquery/google/cloud/bigquery/_pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,13 @@ def pyarrow_timestamp():
"TIME": pyarrow_time,
"TIMESTAMP": pyarrow_timestamp,
}
ARROW_SCALARS_TO_BQ = {
arrow_type(): bq_type # TODO: explain wht calling arrow_type()
for bq_type, arrow_type in BQ_TO_ARROW_SCALARS.items()
}
else: # pragma: NO COVER
BQ_TO_ARROW_SCALARS = {} # pragma: NO COVER
ARROW_SCALARS_TO_BQ = {} # pragma: NO_COVER


def bq_to_arrow_struct_data_type(field):
Expand Down Expand Up @@ -140,10 +145,11 @@ def bq_to_arrow_data_type(field):
return pyarrow.list_(inner_type)
return None

if field.field_type.upper() in schema._STRUCT_TYPES:
field_type_upper = field.field_type.upper() if field.field_type else ""
if field_type_upper in schema._STRUCT_TYPES:
return bq_to_arrow_struct_data_type(field)

data_type_constructor = BQ_TO_ARROW_SCALARS.get(field.field_type.upper())
data_type_constructor = BQ_TO_ARROW_SCALARS.get(field_type_upper)
if data_type_constructor is None:
return None
return data_type_constructor()
Expand Down Expand Up @@ -180,9 +186,12 @@ def bq_to_arrow_schema(bq_schema):

def bq_to_arrow_array(series, bq_field):
arrow_type = bq_to_arrow_data_type(bq_field)

field_type_upper = bq_field.field_type.upper() if bq_field.field_type else ""

if bq_field.mode.upper() == "REPEATED":
return pyarrow.ListArray.from_pandas(series, type=arrow_type)
if bq_field.field_type.upper() in schema._STRUCT_TYPES:
if field_type_upper in schema._STRUCT_TYPES:
return pyarrow.StructArray.from_pandas(series, type=arrow_type)
return pyarrow.array(series, type=arrow_type)

Expand Down Expand Up @@ -273,7 +282,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
bq_type = _PANDAS_DTYPE_TO_BQ.get(dtype.name)
if not bq_type:
warnings.warn(u"Unable to determine type of column '{}'.".format(column))
return None

bq_field = schema.SchemaField(column, bq_type)
bq_schema_out.append(bq_field)

Expand All @@ -285,6 +294,44 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
bq_schema_unused
)
)

# If schema detection was not successful for all columns, also try with
# pyarrow, if available.
if any(field.field_type is None for field in bq_schema_out):
if not pyarrow:
return None # We cannot detect the schema in full.

arrow_table = dataframe_to_arrow(dataframe, bq_schema_out)
arrow_schema_index = {field.name: field.type for field in arrow_table}

currated_schema = []
for schema_field in bq_schema_out:
if schema_field.field_type is not None:
currated_schema.append(schema_field)
continue

detected_type = ARROW_SCALARS_TO_BQ.get(
arrow_schema_index.get(schema_field.name)
)
if detected_type is None:
warnings.warn(
u"Pyarrow could not determine the type of column '{}'.".format(
schema_field.name
)
)
return None

new_field = schema.SchemaField(
name=schema_field.name,
field_type=detected_type,
mode=schema_field.mode,
description=schema_field.description,
fields=schema_field.fields,
)
currated_schema.append(new_field)

bq_schema_out = currated_schema

return tuple(bq_schema_out)


Expand Down
66 changes: 66 additions & 0 deletions bigquery/tests/unit/test__pandas_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,69 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
call_args = fake_write_table.call_args
assert call_args is not None
assert call_args.kwargs.get("compression") == "ZSTD"


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_dataframe_to_bq_schema_fallback_needed_wo_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": "FOO", "execution_date": datetime.date(2019, 5, 10)},
{"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

no_pyarrow_patch = mock.patch(module_under_test.__name__ + ".pyarrow", None)

with no_pyarrow_patch:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

assert detected_schema is None


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_fallback_needed_w_pyarrow(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": "FOO", "created_at": datetime.date(2019, 5, 10)},
{"id": 20, "status": "BAR", "created_at": datetime.date(2018, 9, 12)},
]
)

detected_schema = module_under_test.dataframe_to_bq_schema(dataframe, bq_schema=[])
expected_schema = (
schema.SchemaField("id", "INTEGER", mode="NULLABLE"),
schema.SchemaField("status", "STRING", mode="NULLABLE"),
schema.SchemaField("created_at", "DATE", mode="NULLABLE"),
)
assert detected_schema == expected_schema


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_dataframe_to_bq_schema_pyarrow_fallback_fails(module_under_test):
dataframe = pandas.DataFrame(
data=[
{"id": 10, "status": "FOO", "all_items": [10.1, 10.2]},
{"id": 20, "status": "BAR", "all_items": [20.1, 20.2]},
]
)

with warnings.catch_warnings(record=True) as warned:
detected_schema = module_under_test.dataframe_to_bq_schema(
dataframe, bq_schema=[]
)

assert detected_schema is None

expected_warnings = []
for warning in warned:
if "Pyarrow could not" in str(warning):
expected_warnings.append(warning)

assert len(expected_warnings) == 1
warning_msg = str(expected_warnings[0])
assert "all_items" in warning_msg
assert "could not determine the type" in warning_msg
14 changes: 9 additions & 5 deletions bigquery/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5704,8 +5704,7 @@ def test_load_table_from_dataframe_unknown_table(self):
)

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_no_schema_warning(self):
def test_load_table_from_dataframe_no_schema_warning_wo_pyarrow(self):
client = self._make_client()

# Pick at least one column type that translates to Pandas dtype
Expand All @@ -5722,9 +5721,12 @@ def test_load_table_from_dataframe_no_schema_warning(self):
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
pyarrow_patch = mock.patch("google.cloud.bigquery.client.pyarrow", None)
pyarrow_patch_helpers = mock.patch(
"google.cloud.bigquery._pandas_helpers.pyarrow", None
)
catch_warnings = warnings.catch_warnings(record=True)

with get_table_patch, load_patch, pyarrow_patch, catch_warnings as warned:
with get_table_patch, load_patch, pyarrow_patch, pyarrow_patch_helpers, catch_warnings as warned:
client.load_table_from_dataframe(
dataframe, self.TABLE_REF, location=self.LOCATION
)
Expand Down Expand Up @@ -5892,7 +5894,6 @@ def test_load_table_from_dataframe_w_partial_schema_extra_types(self):
assert "unknown_col" in message

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_load_table_from_dataframe_w_partial_schema_missing_types(self):
from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
from google.cloud.bigquery import job
Expand All @@ -5909,10 +5910,13 @@ def test_load_table_from_dataframe_w_partial_schema_missing_types(self):
load_patch = mock.patch(
"google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
)
pyarrow_patch = mock.patch(
"google.cloud.bigquery._pandas_helpers.pyarrow", None
)

schema = (SchemaField("string_col", "STRING"),)
job_config = job.LoadJobConfig(schema=schema)
with load_patch as load_table_from_file, warnings.catch_warnings(
with pyarrow_patch, load_patch as load_table_from_file, warnings.catch_warnings(
record=True
) as warned:
client.load_table_from_dataframe(
Expand Down

0 comments on commit 8ab716f

Please sign in to comment.