Skip to content
This repository has been archived by the owner on Jun 7, 2023. It is now read-only.

Commit

Permalink
first implementation of get_predictions.hdf5, plus accompanying tests…
Browse files Browse the repository at this point in the history
…! see #26
  • Loading branch information
hans committed May 13, 2020
1 parent 1cbaa13 commit 193cc3f
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 9 deletions.
6 changes: 4 additions & 2 deletions models/GRNN/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
perl && \
rm -rf /var/lib/apt/lists/*

# Add runtime dependencies.
RUN pip install h5py

# Add test dependencies.
RUN pip install nose rednose jsonschema
ENV NOSE_REDNOSE 1
RUN pip install nose jsonschema

# Copy in tokenizer.
COPY ${MODEL_ROOT}/tokenizer /opt/tokenizer
Expand Down
12 changes: 12 additions & 0 deletions models/GRNN/bin/get_predictions.hdf5
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/sh
GRNN_ROOT="/opt/colorlessgreenRNNs"
INPUT_FILE="$1"
OUTPUT_FILE="$2"

/opt/bin/tokenize $1 > /tmp/input_tokenized

python ${GRNN_ROOT}/src/language_models/evaluate_target_word_test.py \
--data ${GRNN_ROOT}/data/wiki \
--checkpoint ${GRNN_ROOT}/hidden650_batch128_dropout0.2_lr20.0.pt \
--prefixfile /tmp/input_tokenized \
--mode predictions --outf $OUTPUT_FILE
65 changes: 59 additions & 6 deletions models/GRNN/evaluate_target_word_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import argparse
import sys

import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -66,7 +67,10 @@
prefix = dictionary_corpus.tokenize(dictionary, args.prefixfile)


def get_surprisals(sentences, model, dictionary, seed, device="cpu"):
def _get_predictions_inner(sentences, model, dictionary, seed, device="cpu"):
"""
Returns torch tensors. See `get_predictions` for Numpy returns.
"""
ntokens = dictionary.__len__()

with torch.no_grad():
Expand All @@ -76,23 +80,50 @@ def get_surprisals(sentences, model, dictionary, seed, device="cpu"):
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

prev_word = None
sentence_surprisals = []
sentence_predictions = []
for j, word in enumerate(sentence):
if j == 0:
word_surprisal = 0.
sentence_predictions.append(None)
else:
input.fill_(prev_word.item())
output, hidden = model(input, hidden)

# Compute word-level surprisals
word_softmax = F.softmax(output, dim=2)
word_surprisals = -torch.log2(word_softmax)
word_surprisals = word_surprisals.squeeze().cpu()
word_surprisal = word_surprisals[word].item()
sentence_predictions.append(word_softmax)

sentence_surprisals.append((dictionary.idx2word[word.item()], word_surprisal))
prev_word = word

yield sentence_predictions


def get_predictions(sentences, model, dictionary, seed, device="cpu"):
ret = _get_predictions_inner(sentences, model, dictionary, seed, device=device)
for sentence_preds in ret:
ret_i = np.array([preds.cpu() if preds is not None else preds
for preds in sentence_preds])
yield ret_i


def get_surprisals(sentences, model, dictionary, seed, device="cpu"):
ntokens = dictionary.__len__()

with torch.no_grad():
predictions = _get_predictions_inner(sentences, model, dictionary, seed, device=device)

for i, (sentence, sentence_preds) in enumerate(zip(sentences, predictions)):
sentence_surprisals = []
for j, (word_j, preds_j) in enumerate(zip(sentence, sentence_preds)):
word_id = word_j.item()

if preds_j is None:
word_surprisal = 0.
else:
word_surprisal = -torch.log2(preds_j).squeeze().cpu()[word_id]

sentence_surprisals.append((dictionary.idx2word[word_id], word_surprisal))

yield sentence_surprisals


Expand All @@ -116,3 +147,25 @@ def get_surprisals(sentences, model, dictionary, seed, device="cpu"):
for i, sentence_surps in enumerate(surprisals):
for j, (word, word_surp) in enumerate(sentence_surps):
outf.write("%i\t%i\t%s\t%f\n" % (i + 1, j + 1, word, word_surp))
elif args.mode == "predictions":
outf = h5py.File(args.outf.name, args.outf.mode)

predictions = get_predictions(sentences, model, dictionary, args.seed, device)
for i, (sentence, sentence_preds) in enumerate(zip(sentences, predictions)):
sentence = [token_id.item() for token_id in sentence]

# Skip the first word, which has null predictions
sentence_preds = [word_preds.squeeze().cpu() for word_preds in sentence_preds[1:]]
first_word_pred = np.ones_like(sentence_preds[0])
first_word_pred /= first_word_pred.sum()
sentence_preds = np.vstack([first_word_pred] + sentence_preds)

group = outf.create_group("/sentence/%i" % i)
group.create_dataset("predictions", data=sentence_preds)
group.create_dataset("tokens", data=sentence)

vocab_encoded = np.array(dictionary.idx2word)
vocab_encoded = np.char.encode(vocab_encoded, "utf-8")
outf.create_dataset("/vocabulary", data=vocab_encoded)

outf.close()
61 changes: 60 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
# coding=utf-8

from collections import defaultdict
from functools import lru_cache
import json
import re
import subprocess
import sys
from tempfile import NamedTemporaryFile
import unittest

import h5py
import numpy as np

import jsonschema
from nose.tools import *

Expand All @@ -26,6 +30,7 @@
flags=re.MULTILINE)


@lru_cache()
def get_spec():
return json.loads(subprocess.check_output(["spec"]).decode("utf-8"))

Expand Down Expand Up @@ -58,7 +63,9 @@ def setUpClass(cls):
else:
text_f = NamedTemporaryFile("w", encoding="utf-8")

with text_f:
predictions_f = NamedTemporaryFile("wb")

with text_f, predictions_f:
test_string = TEST_STRING
if sys.version_info[0] == 2:
test_string = TEST_STRING.encode("utf-8")
Expand All @@ -79,6 +86,11 @@ def setUpClass(cls):
cls.surprisals_output = subprocess.check_output(["get_surprisals", text_f.name]).decode("utf-8")
print(cls.surprisals_output)

print("== get_predictions.hdf5 %s" % text_f.name)
cls.predictions_output = subprocess.check_output(["get_predictions.hdf5", text_f.name, predictions_f.name]).decode("utf-8")
print(cls.predictions_output)
cls.predictions_data = h5py.File(predictions_f.name, "r")

cls.tokenized_lines = [line.strip() for line in cls.tokenized_output.strip().split("\n")]
cls.unkified_lines = [line.strip() for line in cls.unkified_output.strip().split("\n")]
cls.surprisal_lines = [line.strip().split("\t") for line in cls.surprisals_output.strip().split("\n")]
Expand Down Expand Up @@ -149,6 +161,53 @@ def test_surprisal_determinism(self):
# TODO
...

def test_tokenization_match_predictions(self):
"""
Tokenized sequence should exactly match size of predictions array
"""
print(self.predictions_data)
eq_(len(self.predictions_data["/sentence"]), len(self.tokenized_lines),
"Number of lines in predictions output should match number of tokenized lines")

vocabulary = get_spec()["vocabulary"]["items"]
vocab_size = len(vocabulary)

for i, sentence in self.predictions_data["/sentence"].items():
i = int(i)
tokenized_sentence = self.tokenized_lines[i]
tokens = tokenized_sentence.split(" ")
eq_(len(sentence["predictions"]), len(tokens))
eq_(len(sentence["tokens"]), len(tokens))
eq_(sentence["predictions"].shape[1], vocab_size)


def test_predictions_quantatitive(self):
for i, sentence in self.predictions_data["/sentence"].items():
for word_preds in sentence["predictions"]:
# Predictions should be valid probability distribution
ok_(((word_preds >= 0) & (word_preds <= 1)).all(),
"Prediction distributions must have entries in [0, 1]")
np.testing.assert_almost_equal(word_preds.sum(), 1, decimal=3)

def test_predictions_vocabulary(self):
"""
Token IDs in prediction output should match the IDs we reconstruct from
the /vocabulary dataset.
"""

vocabulary = self.predictions_data["/vocabulary"]
# Decode bytestring to UTF-8
vocabulary = np.char.decode(vocabulary, "utf-8")
vocab_size = len(vocabulary)

spec_vocab = get_spec()["vocabulary"]["items"]

ok_(vocab_size > 0)
eq_(vocab_size, len(spec_vocab),
"Prediction vocabulary should match size stated in model spec")
eq_(set(vocabulary), set(spec_vocab),
"Vocabulary items should match exactly (not necessarily in order)")


if __name__ == "__main__":
import nose
Expand Down

0 comments on commit 193cc3f

Please sign in to comment.