Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use StandardSqlField class for Model.feature_columns and Model.label_columns #1117

Merged
merged 1 commit into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions google/cloud/bigquery/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import google.cloud._helpers # type: ignore
from google.cloud.bigquery import _helpers
from google.cloud.bigquery import standard_sql
from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration


Expand Down Expand Up @@ -171,26 +172,32 @@ def training_runs(self) -> Sequence[Dict[str, Any]]:
)

@property
# NOTE(review): this span is a rendered diff, not a single valid definition.
# The first `def` line below is the pre-change signature (raw dicts); the
# second `def` line is its replacement returning StandardSqlField objects.
def feature_columns(self) -> Sequence[Dict[str, Any]]:
def feature_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Input feature columns that were used to train this model.

Read-only.
"""
# NOTE(review): `return typing.cast(` is the removed pre-change opener; the
# `resource: ... = typing.cast(` line that follows replaces it. Both share
# the cast arguments and closing parenthesis below.
return typing.cast(
resource: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("featureColumns", [])
)
# Post-change behavior: each raw "featureColumns" entry (empty list when the
# key is absent) is converted with StandardSqlField.from_api_repr.
return [
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
# NOTE(review): this span is a rendered diff, not a single valid definition.
# The first `def` line below is the pre-change signature (raw dicts); the
# second `def` line is its replacement returning StandardSqlField objects.
def label_columns(self) -> Sequence[Dict[str, Any]]:
def label_columns(self) -> Sequence[standard_sql.StandardSqlField]:
"""Label columns that were used to train this model.

The output of the model will have a ``predicted_`` prefix to these columns.

Read-only.
"""
# NOTE(review): `return typing.cast(` is the removed pre-change opener; the
# `resource: ... = typing.cast(` line that follows replaces it. Both share
# the cast arguments and closing parenthesis below.
return typing.cast(
resource: Sequence[Dict[str, Any]] = typing.cast(
Sequence[Dict[str, Any]], self._properties.get("labelColumns", [])
)
# Post-change behavior: each raw "labelColumns" entry (empty list when the
# key is absent) is converted with StandardSqlField.from_api_repr.
return [
standard_sql.StandardSqlField.from_api_repr(column) for column in resource
]

@property
def best_trial_id(self) -> Optional[int]:
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,46 @@ def test_build_resource(object_under_test, resource, filter_fields, expected):
assert got == expected


def test_feature_columns(object_under_test):
    """``feature_columns`` exposes raw ``featureColumns`` entries as ``StandardSqlField``."""
    from google.cloud.bigquery import standard_sql

    type_names = standard_sql.StandardSqlTypeNames

    # Raw API representation stored directly on the model's properties.
    raw_columns = [
        {"name": "col_1", "type": {"typeKind": "STRING"}},
        {"name": "col_2", "type": {"typeKind": "FLOAT64"}},
    ]
    object_under_test._properties["featureColumns"] = raw_columns

    string_field = standard_sql.StandardSqlField(
        "col_1", standard_sql.StandardSqlDataType(type_names.STRING)
    )
    float_field = standard_sql.StandardSqlField(
        "col_2", standard_sql.StandardSqlDataType(type_names.FLOAT64)
    )

    assert object_under_test.feature_columns == [string_field, float_field]


def test_label_columns(object_under_test):
    """``label_columns`` exposes raw ``labelColumns`` entries as ``StandardSqlField``."""
    from google.cloud.bigquery import standard_sql

    type_names = standard_sql.StandardSqlTypeNames

    # Raw API representation stored directly on the model's properties.
    raw_columns = [
        {"name": "col_1", "type": {"typeKind": "STRING"}},
        {"name": "col_2", "type": {"typeKind": "FLOAT64"}},
    ]
    object_under_test._properties["labelColumns"] = raw_columns

    string_field = standard_sql.StandardSqlField(
        "col_1", standard_sql.StandardSqlDataType(type_names.STRING)
    )
    float_field = standard_sql.StandardSqlField(
        "col_2", standard_sql.StandardSqlDataType(type_names.FLOAT64)
    )

    assert object_under_test.label_columns == [string_field, float_field]


def test_set_description(object_under_test):
assert not object_under_test.description
object_under_test.description = "A model description."
Expand Down