def to_standard_sql(self):
    """Convert this schema field to its standard SQL field representation.

    The field type name is translated through ``LEGACY_TO_STANDARD_TYPES``;
    a name missing from that mapping becomes ``TYPE_KIND_UNSPECIFIED``. A
    field whose mode is ``"REPEATED"`` is returned as an ``ARRAY`` whose
    element type is the translated field type. Sub-fields of a ``STRUCT``
    (either a plain one or the element type of an array) are converted
    recursively with this same method.

    Returns:
        An instance of :class:`~google.cloud.bigquery_v2.types.StandardSqlField`.
    """
    kinds = types.StandardSqlDataType

    # Kind implied by the declared type name alone, ignoring the mode.
    # NOTE(review): the mode check and this lookup are exact, case-sensitive
    # string matches - confirm field_type / mode are always upper-case here.
    base_kind = LEGACY_TO_STANDARD_TYPES.get(
        self.field_type, kinds.TYPE_KIND_UNSPECIFIED
    )

    converted = kinds()

    if self.mode == "REPEATED":
        # The mapping has no ARRAY entry; arrays are expressed only through
        # the mode, with the declared type describing the array elements.
        converted.type_kind = kinds.ARRAY
        element = converted.array_element_type
        element.type_kind = base_kind
        # ARRAY cannot directly contain other arrays, only scalar types and
        # STRUCTs, so STRUCT is the only element kind with nested fields.
        # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#array-type
        struct_owner = element if base_kind == kinds.STRUCT else None
    else:
        converted.type_kind = base_kind
        struct_owner = converted if base_kind == kinds.STRUCT else None

    if struct_owner is not None:
        struct_owner.struct_type.fields.extend(
            sub_field.to_standard_sql() for sub_field in self.fields
        )

    return types.StandardSqlField(name=self.name, type=converted)
("INTEGER", sql_type.INT64), + ("FLOAT", sql_type.FLOAT64), + ("BOOLEAN", sql_type.BOOL), + ("DATETIME", sql_type.DATETIME), + # a few standard types + ("INT64", sql_type.INT64), + ("FLOAT64", sql_type.FLOAT64), + ("BOOL", sql_type.BOOL), + ("GEOGRAPHY", sql_type.GEOGRAPHY), + ) + for legacy_type, standard_type in examples: + field = self._make_one("some_field", legacy_type) + standard_field = field.to_standard_sql() + self.assertEqual(standard_field.name, "some_field") + self.assertEqual(standard_field.type.type_kind, standard_type) + self.assertFalse(standard_field.type.HasField("sub_type")) + + def test_to_standard_sql_struct_type(self): + from google.cloud.bigquery_v2 import types + + # Expected result object: + # + # name: "image_usage" + # type { + # type_kind: STRUCT + # struct_type { + # fields { + # name: "image_content" + # type {type_kind: BYTES} + # } + # fields { + # name: "last_used" + # type { + # type_kind: STRUCT + # struct_type { + # fields { + # name: "date_field" + # type {type_kind: DATE} + # } + # fields { + # name: "time_field" + # type {type_kind: TIME} + # } + # } + # } + # } + # } + # } + + sql_type = self._get_standard_sql_data_type_class() + + # level 2 fields + sub_sub_field_date = types.StandardSqlField( + name="date_field", type=sql_type(type_kind=sql_type.DATE) + ) + sub_sub_field_time = types.StandardSqlField( + name="time_field", type=sql_type(type_kind=sql_type.TIME) + ) + + # level 1 fields + sub_field_struct = types.StandardSqlField( + name="last_used", type=sql_type(type_kind=sql_type.STRUCT) + ) + sub_field_struct.type.struct_type.fields.extend( + [sub_sub_field_date, sub_sub_field_time] + ) + sub_field_bytes = types.StandardSqlField( + name="image_content", type=sql_type(type_kind=sql_type.BYTES) + ) + + # level 0 (top level) + expected_result = types.StandardSqlField( + name="image_usage", type=sql_type(type_kind=sql_type.STRUCT) + ) + expected_result.type.struct_type.fields.extend( + [sub_field_bytes, sub_field_struct] + ) 
+ + # construct legacy SchemaField object + sub_sub_field1 = self._make_one("date_field", "DATE") + sub_sub_field2 = self._make_one("time_field", "TIME") + sub_field_record = self._make_one( + "last_used", "RECORD", fields=(sub_sub_field1, sub_sub_field2) + ) + sub_field_bytes = self._make_one("image_content", "BYTES") + + for type_name in ("RECORD", "STRUCT"): + schema_field = self._make_one( + "image_usage", type_name, fields=(sub_field_bytes, sub_field_record) + ) + standard_field = schema_field.to_standard_sql() + self.assertEqual(standard_field, expected_result) + + def test_to_standard_sql_array_type_simple(self): + from google.cloud.bigquery_v2 import types + + sql_type = self._get_standard_sql_data_type_class() + + # construct expected result object + expected_sql_type = sql_type(type_kind=sql_type.ARRAY) + expected_sql_type.array_element_type.type_kind = sql_type.INT64 + expected_result = types.StandardSqlField( + name="valid_numbers", type=expected_sql_type + ) + + # construct "repeated" SchemaField object and convert to standard SQL + schema_field = self._make_one("valid_numbers", "INT64", mode="REPEATED") + standard_field = schema_field.to_standard_sql() + + self.assertEqual(standard_field, expected_result) + + def test_to_standard_sql_array_type_struct(self): + from google.cloud.bigquery_v2 import types + + sql_type = self._get_standard_sql_data_type_class() + + # define person STRUCT + name_field = types.StandardSqlField( + name="name", type=sql_type(type_kind=sql_type.STRING) + ) + age_field = types.StandardSqlField( + name="age", type=sql_type(type_kind=sql_type.INT64) + ) + person_struct = types.StandardSqlField( + name="person_info", type=sql_type(type_kind=sql_type.STRUCT) + ) + person_struct.type.struct_type.fields.extend([name_field, age_field]) + + # define expected result - an ARRAY of person structs + expected_sql_type = sql_type( + type_kind=sql_type.ARRAY, array_element_type=person_struct.type + ) + expected_result = 
def test_to_standard_sql_unknown_type(self):
    # "TROOLEAN" is not a recognised type name, so the converted field is
    # expected to carry the "unspecified" type kind.
    converted = self._make_one("weird_field", "TROOLEAN").to_standard_sql()
    unspecified = self._get_standard_sql_data_type_class().TYPE_KIND_UNSPECIFIED

    self.assertEqual(converted.name, "weird_field")

    converted_type = converted.type
    self.assertEqual(converted_type.type_kind, unspecified)
    # Presumably "sub_type" is the field group holding array/struct details;
    # nothing in it should be set for an unrecognised scalar type.
    self.assertFalse(converted_type.HasField("sub_type"))