
Commit 7c1a4bf

Replace the usage of FLAGS.bert_hub_module_handle with a function argument to facilitate code reuse in colabs.

jacobdevlin-google committed Mar 20, 2019
1 parent ffbda2a
Showing 1 changed file with 10 additions and 11 deletions.

run_classifier_with_tfhub.py
@@ -35,15 +35,12 @@
 
 
 def create_model(is_training, input_ids, input_mask, segment_ids, labels,
-                 num_labels):
+                 num_labels, bert_hub_module_handle):
   """Creates a classification model."""
   tags = set()
   if is_training:
     tags.add("train")
-  bert_module = hub.Module(
-      FLAGS.bert_hub_module_handle,
-      tags=tags,
-      trainable=True)
+  bert_module = hub.Module(bert_hub_module_handle, tags=tags, trainable=True)
   bert_inputs = dict(
       input_ids=input_ids,
       input_mask=input_mask,
@@ -87,7 +84,7 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
 
 
 def model_fn_builder(num_labels, learning_rate, num_train_steps,
-                     num_warmup_steps, use_tpu):
+                     num_warmup_steps, use_tpu, bert_hub_module_handle):
   """Returns `model_fn` closure for TPUEstimator."""
 
   def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
@@ -105,7 +102,8 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 
     (total_loss, per_example_loss, logits) = create_model(
-        is_training, input_ids, input_mask, segment_ids, label_ids, num_labels)
+        is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
+        bert_hub_module_handle)
 
     output_spec = None
     if mode == tf.estimator.ModeKeys.TRAIN:
@@ -140,10 +138,10 @@ def metric_fn(per_example_loss, label_ids, logits):
   return model_fn
 
 
-def create_tokenizer_from_hub_module():
+def create_tokenizer_from_hub_module(bert_hub_module_handle):
   """Get the vocab file and casing info from the Hub module."""
   with tf.Graph().as_default():
-    bert_module = hub.Module(FLAGS.bert_hub_module_handle)
+    bert_module = hub.Module(bert_hub_module_handle)
     tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
     with tf.Session() as sess:
       vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"],
@@ -175,7 +173,7 @@ def main(_):
 
   label_list = processor.get_labels()
 
-  tokenizer = create_tokenizer_from_hub_module()
+  tokenizer = create_tokenizer_from_hub_module(FLAGS.bert_hub_module_handle)
 
   tpu_cluster_resolver = None
   if FLAGS.use_tpu and FLAGS.tpu_name:
@@ -207,7 +205,8 @@ def main(_):
       learning_rate=FLAGS.learning_rate,
       num_train_steps=num_train_steps,
       num_warmup_steps=num_warmup_steps,
-      use_tpu=FLAGS.use_tpu)
+      use_tpu=FLAGS.use_tpu,
+      bert_hub_module_handle=FLAGS.bert_hub_module_handle)
 
   # If TPU is not available, this will fall back to normal Estimator on CPU
   # or GPU.
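For context, the point of threading the handle through as an argument is that the functions can then be driven from a colab without mutating the FLAGS global. Below is a minimal sketch of such use (not part of this commit): it assumes run_classifier_with_tfhub.py is importable, and the TF-Hub URL and hyperparameter values are illustrative placeholders.

    # Colab-style sketch (illustrative, not from the commit). Assumes
    # run_classifier_with_tfhub.py is on the Python path; the module URL
    # and hyperparameters are placeholder choices.
    from run_classifier_with_tfhub import (create_tokenizer_from_hub_module,
                                           model_fn_builder)

    BERT_MODULE_HANDLE = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

    # The tokenizer is now built from an explicit handle...
    tokenizer = create_tokenizer_from_hub_module(BERT_MODULE_HANDLE)

    # ...and the same handle is passed straight into the model_fn closure,
    # with no dependency on FLAGS.bert_hub_module_handle.
    model_fn = model_fn_builder(
        num_labels=2,          # e.g. a binary task such as MRPC
        learning_rate=2e-5,
        num_train_steps=1000,
        num_warmup_steps=100,
        use_tpu=False,
        bert_hub_module_handle=BERT_MODULE_HANDLE)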
