From d66a146741588fb208450bde15aa7db143baaa69 Mon Sep 17 00:00:00 2001
From: Jacob Devlin
Date: Thu, 28 Mar 2019 09:29:59 -0700
Subject: [PATCH] (1) Updating TF Hub classifier (2) Updating tokenizer to support emojis

---
 run_classifier_with_tfhub.py | 48 +++++++++++++++++++++++++++++++++---
 tokenization.py              |  2 +-
 tokenization_test.py         |  1 +
 3 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/run_classifier_with_tfhub.py b/run_classifier_with_tfhub.py
index f42b4f74a..9d2f80f6b 100644
--- a/run_classifier_with_tfhub.py
+++ b/run_classifier_with_tfhub.py
@@ -73,6 +73,7 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
 
     logits = tf.matmul(output_layer, output_weights, transpose_b=True)
     logits = tf.nn.bias_add(logits, output_bias)
+    probabilities = tf.nn.softmax(logits, axis=-1)
     log_probs = tf.nn.log_softmax(logits, axis=-1)
 
     one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
@@ -80,7 +81,7 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
     per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
     loss = tf.reduce_mean(per_example_loss)
 
-    return (loss, per_example_loss, logits)
+    return (loss, per_example_loss, logits, probabilities)
 
 
 def model_fn_builder(num_labels, learning_rate, num_train_steps,
@@ -101,7 +102,7 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
 
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 
-    (total_loss, per_example_loss, logits) = create_model(
+    (total_loss, per_example_loss, logits, probabilities) = create_model(
         is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
         bert_hub_module_handle)
 
@@ -130,8 +131,12 @@ def metric_fn(per_example_loss, label_ids, logits):
           mode=mode,
           loss=total_loss,
           eval_metrics=eval_metrics)
+    elif mode == tf.estimator.ModeKeys.PREDICT:
+      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
+          mode=mode, predictions={"probabilities": probabilities})
     else:
-      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
+      raise ValueError(
+          "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))
 
     return output_spec
 
@@ -215,7 +220,8 @@ def main(_):
       model_fn=model_fn,
       config=run_config,
       train_batch_size=FLAGS.train_batch_size,
-      eval_batch_size=FLAGS.eval_batch_size)
+      eval_batch_size=FLAGS.eval_batch_size,
+      predict_batch_size=FLAGS.predict_batch_size)
 
   if FLAGS.do_train:
     train_features = run_classifier.convert_examples_to_features(
@@ -265,6 +271,40 @@ def main(_):
         tf.logging.info("  %s = %s", key, str(result[key]))
         writer.write("%s = %s\n" % (key, str(result[key])))
 
+  if FLAGS.do_predict:
+    predict_examples = processor.get_test_examples(FLAGS.data_dir)
+    if FLAGS.use_tpu:
+      # Discard batch remainder if running on TPU
+      n = len(predict_examples)
+      predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)]
+
+    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
+    run_classifier.file_based_convert_examples_to_features(
+        predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
+        predict_file)
+
+    tf.logging.info("***** Running prediction *****")
+    tf.logging.info("  Num examples = %d", len(predict_examples))
+    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
+
+    predict_input_fn = run_classifier.file_based_input_fn_builder(
+        input_file=predict_file,
+        seq_length=FLAGS.max_seq_length,
+        is_training=False,
+        drop_remainder=FLAGS.use_tpu)
+
+    result = estimator.predict(input_fn=predict_input_fn)
+
+    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
+    with tf.gfile.GFile(output_predict_file, "w") as writer:
+      tf.logging.info("***** Predict results *****")
+      for prediction in result:
+        probabilities = prediction["probabilities"]
+        output_line = "\t".join(
+            str(class_probability)
+            for class_probability in probabilities) + "\n"
+        writer.write(output_line)
+
 
 if __name__ == "__main__":
   flags.mark_flag_as_required("data_dir")
diff --git a/tokenization.py b/tokenization.py
index dc476a698..0ee135953 100644
--- a/tokenization.py
+++ b/tokenization.py
@@ -378,7 +378,7 @@ def _is_control(char):
   if char == "\t" or char == "\n" or char == "\r":
     return False
   cat = unicodedata.category(char)
-  if cat.startswith("C"):
+  if cat in ("Cc", "Cf"):
     return True
   return False
 
diff --git a/tokenization_test.py b/tokenization_test.py
index e85a644d1..0afaedd2e 100644
--- a/tokenization_test.py
+++ b/tokenization_test.py
@@ -121,6 +121,7 @@ def test_is_control(self):
     self.assertFalse(tokenization._is_control(u" "))
     self.assertFalse(tokenization._is_control(u"\t"))
     self.assertFalse(tokenization._is_control(u"\r"))
+    self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
 
   def test_is_punctuation(self):
     self.assertTrue(tokenization._is_punctuation(u"-"))
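
Usage note: with --do_predict, the classifier writes test_results.tsv, one
line per test example, holding tab-separated class probabilities in the order
of the processor's label list. A minimal consumer sketch follows; the
label_list value is an assumption for a binary task such as MRPC, so
substitute whatever your processor's get_labels() returns:

import csv

label_list = ["0", "1"]  # assumed: must match processor.get_labels()

with open("test_results.tsv") as f:
  for row in csv.reader(f, delimiter="\t"):
    probabilities = [float(p) for p in row]
    # Pick the index with the highest probability and map it to a label.
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    print(label_list[best], probabilities)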
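Tokenizer note: the likely failure mode fixed here is that
unicodedata.category returns "Cs" (surrogate) for each half of a surrogate
pair, so on narrow Python builds, where emojis are stored as surrogate pairs,
the old cat.startswith("C") test treated emojis as control characters and
silently stripped them. The patched test matches only "Cc" (control) and
"Cf" (format). A small sketch mirroring the new check:

import unicodedata

def _is_control(char):
  # Mirrors the patched tokenization._is_control: tab, newline and carriage
  # return count as whitespace, and only Cc/Cf categories count as control.
  if char == "\t" or char == "\n" or char == "\r":
    return False
  return unicodedata.category(char) in ("Cc", "Cf")

print(unicodedata.category(u"\U0001F4A9"))  # "So" (symbol, other)
print(_is_control(u"\U0001F4A9"))  # False -- the emoji is kept
print(_is_control(u"\x00"))        # True  -- NUL is "Cc"
print(_is_control(u"\u200d"))      # True  -- zero-width joiner is "Cf"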