Running through pyformat to meet Google code standards
jacobdevlin-google committed Nov 9, 2018
1 parent aefad12 commit 3c67c1d
Showing 3 changed files with 20 additions and 17 deletions.
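Editor's note: pyformat is Google's internal Python formatter; the nearest open-source equivalent is yapf with its built-in "google" style. As a rough sketch of how this kind of reformatting can be reproduced (assuming yapf is installed; this is illustrative, not the exact tool invocation used for this commit):

# Illustrative only: reproduces Google-style wrapping with yapf, the
# open-source formatter; pyformat itself is Google-internal.
from yapf.yapflib.yapf_api import FormatCode

source = (
    'flags.DEFINE_bool("do_predict", False, "Whether to run the model in '
    'inference mode on the test set.")\n')

# The "google" style wraps calls at 80 columns with 4-space continuation
# indents, matching the kind of changes in this commit. In most yapf
# releases FormatCode returns a (formatted_source, changed) tuple.
formatted, changed = FormatCode(source, style_config="google")
print(formatted)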
5 changes: 3 additions & 2 deletions extract_features.py
@@ -170,8 +170,9 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument

     tvars = tf.trainable_variables()
     scaffold_fn = None
-    (assignment_map, initialized_variable_names
-    ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
+    (assignment_map,
+     initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
+         tvars, init_checkpoint)
     if use_tpu:
 
       def tpu_scaffold():
30 changes: 16 additions & 14 deletions run_classifier.py
@@ -71,7 +71,9 @@

flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_bool("do_predict", False, "Whether to run the model in inference mode on the test set.")
flags.DEFINE_bool(
"do_predict", False,
"Whether to run the model in inference mode on the test set.")

flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")

@@ -248,8 +250,7 @@ def get_dev_examples(self, data_dir):
   def get_test_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")),
-        "test")
+        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -289,7 +290,7 @@ def get_dev_examples(self, data_dir):
   def get_test_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -329,7 +330,7 @@ def get_dev_examples(self, data_dir):
   def get_test_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
 
   def get_labels(self):
     """See base class."""
@@ -659,9 +660,7 @@ def metric_fn(per_example_loss, label_ids, logits):
           scaffold_fn=scaffold_fn)
     else:
       output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode,
-          predictions=probabilities,
-          scaffold_fn=scaffold_fn)
+          mode=mode, predictions=probabilities, scaffold_fn=scaffold_fn)
     return output_spec
 
   return model_fn
@@ -874,7 +873,8 @@ def main(_):
     predict_examples = processor.get_test_examples(FLAGS.data_dir)
     predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
     file_based_convert_examples_to_features(predict_examples, label_list,
-                                            FLAGS.max_seq_length, tokenizer, predict_file)
+                                            FLAGS.max_seq_length, tokenizer,
+                                            predict_file)
 
     tf.logging.info("***** Running prediction*****")
     tf.logging.info("  Num examples = %d", len(predict_examples))
@@ -887,20 +887,22 @@ def main(_):

     predict_drop_remainder = True if FLAGS.use_tpu else False
     predict_input_fn = file_based_input_fn_builder(
-      input_file=predict_file,
-      seq_length=FLAGS.max_seq_length,
-      is_training=False,
-      drop_remainder=predict_drop_remainder)
+        input_file=predict_file,
+        seq_length=FLAGS.max_seq_length,
+        is_training=False,
+        drop_remainder=predict_drop_remainder)
 
     result = estimator.predict(input_fn=predict_input_fn)
 
     output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
     with tf.gfile.GFile(output_predict_file, "w") as writer:
       tf.logging.info("***** Predict results *****")
       for prediction in result:
-        output_line = "\t".join(str(class_probability) for class_probability in prediction) + "\n"
+        output_line = "\t".join(
+            str(class_probability) for class_probability in prediction) + "\n"
         writer.write(output_line)
 
+
 if __name__ == "__main__":
   flags.mark_flag_as_required("data_dir")
   flags.mark_flag_as_required("task_name")
2 changes: 1 addition & 1 deletion tokenization.py
@@ -86,7 +86,7 @@ def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
output.append(vocab[item])
return output
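Editor's note: convert_by_vocab handles both directions of the lookup (tokens to ids and ids to tokens), since each is just a dict access per item. A toy usage sketch, assuming the function from tokenization.py above; the vocab here is made up for illustration, not loaded from a real vocab.txt:

# Hypothetical toy vocab; BERT loads its real vocabulary from vocab.txt.
vocab = {"[CLS]": 101, "hello": 7592, "[SEP]": 102}
inv_vocab = {v: k for k, v in vocab.items()}

print(convert_by_vocab(vocab, ["[CLS]", "hello", "[SEP]"]))  # [101, 7592, 102]
print(convert_by_vocab(inv_vocab, [101, 7592, 102]))  # ['[CLS]', 'hello', '[SEP]']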


