(1) Updating TF Hub classifier (2) Updating tokenizer to support emojis
jacobdevlin-google committed Mar 28, 2019
1 parent 7c1a4bf commit d66a146
Showing 3 changed files with 46 additions and 5 deletions.
48 changes: 44 additions & 4 deletions run_classifier_with_tfhub.py
@@ -73,14 +73,15 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,

logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
probabilities = tf.nn.softmax(logits, axis=-1)
log_probs = tf.nn.log_softmax(logits, axis=-1)

one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)

return (loss, per_example_loss, logits)
return (loss, per_example_loss, logits, probabilities)
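The softmax added above is what the new PREDICT path exports as "probabilities". A minimal sketch (not part of the commit; hypothetical 3-class logits, with NumPy standing in for the TensorFlow ops) of how the returned probabilities and the existing loss relate to the logits:

import numpy as np

logits = np.array([2.0, 0.5, -1.0])                      # hypothetical 3-class logits
probabilities = np.exp(logits) / np.exp(logits).sum()    # what tf.nn.softmax computes
log_probs = logits - np.log(np.exp(logits).sum())        # what tf.nn.log_softmax computes
one_hot_labels = np.eye(3)[0]                            # hypothetical gold label: class 0
per_example_loss = -np.sum(one_hot_labels * log_probs)   # same cross-entropy as create_model
print(probabilities, per_example_loss)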


def model_fn_builder(num_labels, learning_rate, num_train_steps,
@@ -101,7 +102,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument

is_training = (mode == tf.estimator.ModeKeys.TRAIN)

(total_loss, per_example_loss, logits) = create_model(
(total_loss, per_example_loss, logits, probabilities) = create_model(
is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
bert_hub_module_handle)

@@ -130,8 +131,12 @@ def metric_fn(per_example_loss, label_ids, logits):
mode=mode,
loss=total_loss,
eval_metrics=eval_metrics)
elif mode == tf.estimator.ModeKeys.PREDICT:
output_spec = tf.contrib.tpu.TPUEstimatorSpec(
mode=mode, predictions={"probabilities": probabilities})
else:
raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
raise ValueError(
"Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))

return output_spec

@@ -215,7 +220,8 @@ def main(_):
model_fn=model_fn,
config=run_config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size)
eval_batch_size=FLAGS.eval_batch_size,
predict_batch_size=FLAGS.predict_batch_size)

if FLAGS.do_train:
train_features = run_classifier.convert_examples_to_features(
@@ -265,6 +271,40 @@ def main(_):
tf.logging.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))

if FLAGS.do_predict:
predict_examples = processor.get_test_examples(FLAGS.data_dir)
if FLAGS.use_tpu:
# Discard batch remainder if running on TPU
n = len(predict_examples)
predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)]

predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
run_classifier.file_based_convert_examples_to_features(
predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
predict_file)

tf.logging.info("***** Running prediction*****")
tf.logging.info(" Num examples = %d", len(predict_examples))
tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)

predict_input_fn = run_classifier.file_based_input_fn_builder(
input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=FLAGS.use_tpu)

result = estimator.predict(input_fn=predict_input_fn)

output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
with tf.gfile.GFile(output_predict_file, "w") as writer:
tf.logging.info("***** Predict results *****")
for prediction in result:
probabilities = prediction["probabilities"]
output_line = "\t".join(
str(class_probability)
for class_probability in probabilities) + "\n"
writer.write(output_line)
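The do_predict branch writes test_results.tsv with one tab-separated class probability per input example. A minimal sketch of consuming that file (not part of the commit; the label_list below is a hypothetical three-class example, ordered like the processor's label list):

import csv

label_list = ["contradiction", "entailment", "neutral"]  # hypothetical label set
with open("test_results.tsv") as f:
    for row in csv.reader(f, delimiter="\t"):
        probs = [float(p) for p in row]
        predicted = label_list[probs.index(max(probs))]   # argmax over the class probabilities
        print(predicted, probs)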


if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
2 changes: 1 addition & 1 deletion tokenization.py
@@ -378,7 +378,7 @@ def _is_control(char):
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
if cat in ("Cc", "Cf"):
return True
return False
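The narrower test treats only the Cc and Cf categories as control characters. The old startswith("C") test also caught categories such as Cn, which is what an interpreter whose Unicode tables predate a character assigns to it, which is presumably how some emoji were getting stripped during text cleaning. A small sketch of the effect (not part of the commit), assuming only the standard unicodedata module:

import unicodedata

for char in [u"\t", u"\u00ad", u"\U0001F4A9"]:  # tab, soft hyphen, pile-of-poo emoji
    cat = unicodedata.category(char)
    print(repr(char), cat, cat in ("Cc", "Cf"))
# Typically: "\t" is Cc (but handled as whitespace by the early return above),
# "\u00ad" is Cf and is still treated as a control character, and the emoji is
# So (or Cn under older Unicode data) -- no longer stripped either way.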

1 change: 1 addition & 0 deletions tokenization_test.py
@@ -121,6 +121,7 @@ def test_is_control(self):
self.assertFalse(tokenization._is_control(u" "))
self.assertFalse(tokenization._is_control(u"\t"))
self.assertFalse(tokenization._is_control(u"\r"))
self.assertFalse(tokenization._is_control(u"\U0001F4A9"))

def test_is_punctuation(self):
self.assertTrue(tokenization._is_punctuation(u"-"))
