Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Vertex ML pipeline test failures #7727

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package_build/initialize.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ do
ln -sf $BASEDIR/setup.py $BASEDIR/package_build/$CONFIG_NAME/
ln -sf $BASEDIR/dist $BASEDIR/package_build/$CONFIG_NAME/
ln -sf $BASEDIR/tfx $BASEDIR/package_build/$CONFIG_NAME/
ln -sf $BASEDIR/MANIFEST.in $BASEDIR/package_build/$CONFIG_NAME/
ln -sf $BASEDIR/README*.md $BASEDIR/package_build/$CONFIG_NAME/
ln -sf $BASEDIR/LICENSE $BASEDIR/package_build/$CONFIG_NAME/

Expand Down
3 changes: 1 addition & 2 deletions tfx/examples/chicago_taxi_pipeline/taxi_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def _build_keras_model(
output = tf.keras.layers.Dense(1, activation='sigmoid')(
tf.keras.layers.concatenate([deep, wide])
)
output = tf.squeeze(output, -1)

model = tf.keras.Model(input_layers, output)
model.compile(
Expand Down Expand Up @@ -371,4 +370,4 @@ def run_fn(fn_args: fn_args_utils.FnArgs):
model, tf_transform_output
),
}
model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
tf.saved_model.save(model, fn_args.serving_model_dir, signatures=signatures)
141 changes: 58 additions & 83 deletions tfx/experimental/templates/taxi/models/keras_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,98 +106,73 @@ def _build_keras_model(hidden_units, learning_rate):
Returns:
A keras Model.
"""
real_valued_columns = [
tf.feature_column.numeric_column(key, shape=())
for key in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
]
categorical_columns = [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=features.VOCAB_SIZE + features.OOV_SIZE,
default_value=0)
for key in features.transformed_names(features.VOCAB_FEATURE_KEYS)
]
categorical_columns += [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=num_buckets,
default_value=0) for key, num_buckets in zip(
features.transformed_names(features.BUCKET_FEATURE_KEYS),
features.BUCKET_FEATURE_BUCKET_COUNT)
]
categorical_columns += [
tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension
key,
num_buckets=num_buckets,
default_value=0) for key, num_buckets in zip(
features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
features.CATEGORICAL_FEATURE_MAX_VALUES)
]
indicator_column = [
tf.feature_column.indicator_column(categorical_column)
for categorical_column in categorical_columns
]

model = _wide_and_deep_classifier(
# TODO(b/140320729) Replace with premade wide_and_deep keras model
wide_columns=indicator_column,
deep_columns=real_valued_columns,
dnn_hidden_units=hidden_units,
learning_rate=learning_rate)
return model


def _wide_and_deep_classifier(wide_columns, deep_columns, dnn_hidden_units,
learning_rate):
"""Build a simple keras wide and deep model.

Args:
wide_columns: Feature columns wrapped in indicator_column for wide (linear)
part of the model.
deep_columns: Feature columns for deep part of the model.
dnn_hidden_units: [int], the layer sizes of the hidden DNN.
learning_rate: [float], learning rate of the Adam optimizer.

Returns:
A Wide and Deep Keras model
"""
# Keras needs the feature definitions at compile time.
# TODO(b/139081439): Automate generation of input layers from FeatureColumn.
input_layers = {
colname: tf.keras.layers.Input(name=colname, shape=(), dtype=tf.float32)
for colname in features.transformed_names(
features.DENSE_FLOAT_FEATURE_KEYS)
deep_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32)
for colname in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS)
}
input_layers.update({
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
wide_vocab_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
for colname in features.transformed_names(features.VOCAB_FEATURE_KEYS)
})
input_layers.update({
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32')
}
wide_bucket_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
for colname in features.transformed_names(features.BUCKET_FEATURE_KEYS)
})
input_layers.update({
colname: tf.keras.layers.Input(name=colname, shape=(), dtype='int32') for
colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
})

# TODO(b/161952382): Replace with Keras premade models and
# Keras preprocessing layers.
deep = tf.keras.layers.DenseFeatures(deep_columns)(input_layers)
for numnodes in dnn_hidden_units:
}
wide_categorical_input = {
colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
for colname in features.transformed_names(features.CATEGORICAL_FEATURE_KEYS)
}
input_layers = {
**deep_input,
**wide_vocab_input,
**wide_bucket_input,
**wide_categorical_input,
}

deep = tf.keras.layers.concatenate(
[tf.keras.layers.Normalization()(layer) for layer in deep_input.values()]
)
for numnodes in (hidden_units or [100, 70, 50, 25]):
deep = tf.keras.layers.Dense(numnodes)(deep)
wide = tf.keras.layers.DenseFeatures(wide_columns)(input_layers)

output = tf.keras.layers.Dense(
1, activation='sigmoid')(
tf.keras.layers.concatenate([deep, wide]))
output = tf.squeeze(output, -1)
wide_layers = []
for key in features.transformed_names(features.VOCAB_FEATURE_KEYS):
wide_layers.append(
tf.keras.layers.CategoryEncoding(num_tokens=features.VOCAB_SIZE + features.OOV_SIZE)(
input_layers[key]
)
)
for key, num_tokens in zip(
features.transformed_names(features.BUCKET_FEATURE_KEYS),
features.BUCKET_FEATURE_BUCKET_COUNT,
):
wide_layers.append(
tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
input_layers[key]
)
)
for key, num_tokens in zip(
features.transformed_names(features.CATEGORICAL_FEATURE_KEYS),
features.CATEGORICAL_FEATURE_MAX_VALUES,
):
wide_layers.append(
tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)(
input_layers[key]
)
)
wide = tf.keras.layers.concatenate(wide_layers)

output = tf.keras.layers.Dense(1, activation='sigmoid')(
tf.keras.layers.concatenate([deep, wide])
)
output = tf.keras.layers.Reshape((1,))(output)

model = tf.keras.Model(input_layers, output)
model.compile(
loss='binary_crossentropy',
optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
metrics=[tf.keras.metrics.BinaryAccuracy()])
metrics=[tf.keras.metrics.BinaryAccuracy()],
)
model.summary(print_fn=logging.info)
return model

Expand Down Expand Up @@ -240,4 +215,4 @@ def run_fn(fn_args):
'transform_features':
_get_transform_features_signature(model, tf_transform_output),
}
model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
tf.saved_model.save(model, fn_args.serving_model_dir, signatures=signatures)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ModelTest(tf.test.TestCase):
def testBuildKerasModel(self):
built_model = model._build_keras_model(
hidden_units=[1, 1], learning_rate=0.1) # pylint: disable=protected-access
self.assertEqual(len(built_model.layers), 10)
self.assertEqual(len(built_model.layers), 13)

built_model = model._build_keras_model(hidden_units=[1], learning_rate=0.1) # pylint: disable=protected-access
self.assertEqual(len(built_model.layers), 9)
self.assertEqual(len(built_model.layers), 12)
37 changes: 20 additions & 17 deletions tfx/orchestration/kubeflow/v2/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,25 +234,28 @@ def create_pipeline_components(
model_blessing=tfx.dsl.Channel(
type=tfx.types.standard_artifacts.ModelBlessing)).with_id(
'Resolver.latest_blessed_model_resolver')
# Set the TFMA config for Model Evaluation and Validation.
# Uses TFMA to compute a evaluation statistics over features of a model and
# perform quality validation of a candidate model (compared to a baseline).
eval_config = tfma.EvalConfig(
model_specs=[tfma.ModelSpec(signature_name='eval')],
metrics_specs=[
tfma.MetricsSpec(
metrics=[tfma.MetricConfig(class_name='ExampleCount')],
thresholds={
'binary_accuracy':
tfma.MetricThreshold(
value_threshold=tfma.GenericValueThreshold(
lower_bound={'value': 0.5}),
change_threshold=tfma.GenericChangeThreshold(
direction=tfma.MetricDirection.HIGHER_IS_BETTER,
absolute={'value': -1e-10}))
})
model_specs=[
tfma.ModelSpec(
signature_name='serving_default', label_key='tips_xf',
preprocessing_function_names=['transform_features'])
],
slicing_specs=[
tfma.SlicingSpec(),
tfma.SlicingSpec(feature_keys=['trip_start_hour'])
slicing_specs=[tfma.SlicingSpec()],
metrics_specs=[
tfma.MetricsSpec(metrics=[
tfma.MetricConfig(
class_name='BinaryAccuracy',
threshold=tfma.MetricThreshold(
value_threshold=tfma.GenericValueThreshold(
lower_bound={'value': 0.6}),
# Change threshold will be ignored if there is no
# baseline model resolved from MLMD (first run).
change_threshold=tfma.GenericChangeThreshold(
direction=tfma.MetricDirection.HIGHER_IS_BETTER,
absolute={'value': -1e-10})))
])
])
evaluator = tfx.components.Evaluator(
examples=example_gen.outputs['examples'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@
"parameters": {
"eval_config": {
"runtimeValue": {
"constant": "{\n \"metrics_specs\": [\n {\n \"metrics\": [\n {\n \"class_name\": \"ExampleCount\"\n }\n ],\n \"thresholds\": {\n \"binary_accuracy\": {\n \"change_threshold\": {\n \"absolute\": -1e-10,\n \"direction\": \"HIGHER_IS_BETTER\"\n },\n \"value_threshold\": {\n \"lower_bound\": 0.5\n }\n }\n }\n }\n ],\n \"model_specs\": [\n {\n \"signature_name\": \"eval\"\n }\n ],\n \"slicing_specs\": [\n {},\n {\n \"feature_keys\": [\n \"trip_start_hour\"\n ]\n }\n ]\n}"
"constant": "{\n \"metrics_specs\": [\n {\n \"metrics\": [\n {\n \"class_name\": \"BinaryAccuracy\",\n \"threshold\": {\n \"change_threshold\": {\n \"absolute\": -1e-10,\n \"direction\": \"HIGHER_IS_BETTER\"\n },\n \"value_threshold\": {\n \"lower_bound\": 0.6\n }\n }\n }\n ]\n }\n ],\n \"model_specs\": [\n {\n \"label_key\": \"tips_xf\",\n \"preprocessing_function_names\": [\n \"transform_features\"\n ],\n \"signature_name\": \"serving_default\"\n }\n ],\n \"slicing_specs\": [\n {}\n ]\n}"
}
},
"example_splits": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@
"eval_config": {
"runtimeValue": {
"constantValue": {
"stringValue": "{\n \"metrics_specs\": [\n {\n \"metrics\": [\n {\n \"class_name\": \"ExampleCount\"\n }\n ],\n \"thresholds\": {\n \"binary_accuracy\": {\n \"change_threshold\": {\n \"absolute\": -1e-10,\n \"direction\": \"HIGHER_IS_BETTER\"\n },\n \"value_threshold\": {\n \"lower_bound\": 0.5\n }\n }\n }\n }\n ],\n \"model_specs\": [\n {\n \"signature_name\": \"eval\"\n }\n ],\n \"slicing_specs\": [\n {},\n {\n \"feature_keys\": [\n \"trip_start_hour\"\n ]\n }\n ]\n}"
"stringValue": "{\n \"metrics_specs\": [\n {\n \"metrics\": [\n {\n \"class_name\": \"BinaryAccuracy\",\n \"threshold\": {\n \"change_threshold\": {\n \"absolute\": -1e-10,\n \"direction\": \"HIGHER_IS_BETTER\"\n },\n \"value_threshold\": {\n \"lower_bound\": 0.6\n }\n }\n }\n ]\n }\n ],\n \"model_specs\": [\n {\n \"label_key\": \"tips_xf\",\n \"preprocessing_function_names\": [\n \"transform_features\"\n ],\n \"signature_name\": \"serving_default\"\n }\n ],\n \"slicing_specs\": [\n {}\n ]\n}"
}
}
},
Expand Down
3 changes: 1 addition & 2 deletions tfx/tools/docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ WORKDIR ${TFX_DIR}
ARG TFX_DEPENDENCY_SELECTOR
ENV TFX_DEPENDENCY_SELECTOR=${TFX_DEPENDENCY_SELECTOR}

RUN python -m pip install --upgrade pip wheel setuptools
RUN python -m pip install tomli
RUN python -m pip install --upgrade pip wheel setuptools tomli

# TODO(b/175089240): clean up conditional checks on whether ml-pipelines-sdk is
# built after TFX versions <= 0.25 are no longer eligible for cherry-picks.
Expand Down
6 changes: 3 additions & 3 deletions tfx/v1/orchestration/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"""TFX orchestration.experimental module."""

try:
from tfx.orchestration.kubeflow.decorators import exit_handler # pylint: disable=g-import-not-at-top
from tfx.orchestration.kubeflow.decorators import FinalStatusStr # pylint: disable=g-import-not-at-top

from tfx.orchestration.kubeflow.v2.kubeflow_v2_dag_runner import (
KubeflowV2DagRunner,
KubeflowV2DagRunnerConfig,
Expand All @@ -24,11 +27,8 @@

__all__ = [
"FinalStatusStr",
"KubeflowDagRunner",
"KubeflowDagRunnerConfig",
"KubeflowV2DagRunner",
"KubeflowV2DagRunnerConfig",
"LABEL_KFP_SDK_ENV",
"exit_handler",
"get_default_kubeflow_metadata_config",
]
Loading