Skip to content
This repository has been archived by the owner on Jun 7, 2023. It is now read-only.

Commit

Permalink
Clean up GRNN surprisal extraction file
Browse files Browse the repository at this point in the history
  • Loading branch information
hans committed May 13, 2020
1 parent 966de58 commit 1cbaa13
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 60 deletions.
2 changes: 1 addition & 1 deletion models/GRNN/bin/get_surprisals
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ python ${GRNN_ROOT}/src/language_models/evaluate_target_word_test.py \
--data ${GRNN_ROOT}/data/wiki \
--checkpoint ${GRNN_ROOT}/hidden650_batch128_dropout0.2_lr20.0.pt \
--prefixfile /tmp/input_tokenized \
--surprisalmode True
--mode surprisal
101 changes: 42 additions & 59 deletions models/GRNN/evaluate_target_word_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,9 @@
help='temperature - higher will increase diversity')
parser.add_argument('--outf', type=argparse.FileType("w", encoding="utf-8"), default=sys.stdout,
help='output file for generated text')
parser.add_argument("--mode", choices=["surprisal", "predictions"])
parser.add_argument('--prefixfile', type=str, default='-',
help='File with sentence prefix from which to generate continuations')
parser.add_argument('--surprisalmode', type=bool, default=False,
help='Run in surprisal mode; specify sentence with --prefixfile')

args = parser.parse_args()

Expand Down Expand Up @@ -64,41 +63,41 @@

dictionary = dictionary_corpus.Dictionary(args.data)
vocab_size = len(dictionary)

###
prefix = dictionary_corpus.tokenize(dictionary, args.prefixfile)
#print(prefix.shape)
#for w in prefix:
# print(dictionary.idx2word[w.item()])
# try auto-generate
if not args.surprisalmode:
# print(type(prefix))
# print(prefix.shape)
# print(prefix)
hidden = model.init_hidden(1)


def get_surprisals(sentences, model, dictionary, seed, device="cpu"):
    """Lazily yield per-token surprisals (in bits) for each sentence.

    Args:
        sentences: Iterable of sentences; each sentence is a sequence of
            vocabulary indices exposing ``.item()`` (e.g. 0-dim tensors).
        model: Language model providing ``init_hidden(batch_size)`` and
            callable as ``model(input, hidden) -> (output, hidden)``.
            Assumes ``output`` holds unnormalized scores with the vocabulary
            on dimension 2 (e.g. shape ``(1, 1, vocab)``) -- TODO confirm
            against the model definition.
        dictionary: Vocabulary object supporting ``len()`` and an
            ``idx2word`` index-to-token lookup.
        seed: Value passed to ``torch.manual_seed`` before every sentence.
        device: Device on which the model input tensor is placed.

    Yields:
        One list per sentence of ``(token, surprisal)`` tuples. The first
        token of a sentence has no left context and is assigned ``0.``.
    """
    # Function-scope import: the module's top-level imports are not visible
    # from this block, so bring the dependency in locally.
    import math

    vocab_size = len(dictionary)
    bits_per_nat = 1.0 / math.log(2)

    for sentence in sentences:
        sentence_surprisals = []

        # Keep no_grad scoped to the computation only. Yielding from inside
        # the context would leave gradients disabled in the consumer's code
        # for as long as the generator is suspended.
        with torch.no_grad():
            torch.manual_seed(seed)
            hidden = model.init_hidden(1)
            # The random contents are always overwritten by fill_ before the
            # tensor reaches the model; only its shape/dtype/device matter.
            model_input = torch.randint(
                vocab_size, (1, 1), dtype=torch.long).to(device)

            prev_idx = None
            for word in sentence:
                word_idx = word.item()
                if prev_idx is None:
                    word_surprisal = 0.
                else:
                    model_input.fill_(prev_idx)
                    output, hidden = model(model_input, hidden)

                    # log_softmax stays finite where log2(softmax(x)) would
                    # underflow to -inf for very unlikely tokens. Index
                    # before .item() so only one scalar leaves the device.
                    log_probs = F.log_softmax(output, dim=2).squeeze()
                    word_surprisal = -log_probs[word_idx].item() * bits_per_nat

                sentence_surprisals.append(
                    (dictionary.idx2word[word_idx], word_surprisal))
                prev_idx = word_idx

        yield sentence_surprisals


if __name__ == "__main__":
device = torch.device("cuda" if args.cuda else "cpu")
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
with args.outf as outf:
for i in range(args.sentences):
for word in prefix:
#print(word)
#print(word.item())
outf.write(dictionary.idx2word[word.item()] + " ")
input.fill_(word.item())
output, hidden = model(input,hidden)
generated_word = None
while generated_word != "<eos>":
word_weights = output.squeeze().div(args.temperature).exp().cpu()
word_idx = torch.multinomial(word_weights, 1)[0]
input.fill_(word_idx)
generated_word = dictionary.idx2word[word_idx]
outf.write(generated_word + " ")
output, hidden = model(input, hidden)
outf.write("\n")


if args.surprisalmode:
sentences = []
thesentence = []
eosidx = dictionary.word2idx["<eos>"]
Expand All @@ -107,29 +106,13 @@
if w == eosidx:
sentences.append(thesentence)
thesentence = []
ntokens = dictionary.__len__()
device = torch.device("cuda" if args.cuda else "cpu")
with args.outf as outf:
# write header.
outf.write("sentence_id\ttoken_id\ttoken\tsurprisal\n")

for i, sentence in enumerate(sentences):
torch.manual_seed(args.seed)
hidden = model.init_hidden(1)
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
totalsurprisal = 0.0
firstword = sentence[0]
input.fill_(firstword.item())

outf.write("%i\t%i\t%s\t%f\n" % (i + 1, 1, dictionary.idx2word[firstword.item()], 0.))

output, hidden = model(input,hidden)
word_weights = output.squeeze().div(args.temperature).exp().cpu()
word_surprisals = -1*torch.log2(word_weights/sum(word_weights))
for j, word in enumerate(sentence[1:len(prefix)]):
word_surprisal = word_surprisals[word].item()
outf.write("%i\t%i\t%s\t%f\n" % (i + 1, j + 2, dictionary.idx2word[word.item()], word_surprisal))
input.fill_(word.item())
output, hidden = model(input, hidden)
word_weights = output.squeeze().div(args.temperature).exp().cpu()
word_surprisals = -1*torch.log2(word_weights/sum(word_weights))
if args.mode == "surprisal":
with args.outf as outf:
# write header.
outf.write("sentence_id\ttoken_id\ttoken\tsurprisal\n")

surprisals = get_surprisals(sentences, model, dictionary, args.seed, device)
for i, sentence_surps in enumerate(surprisals):
for j, (word, word_surp) in enumerate(sentence_surps):
outf.write("%i\t%i\t%s\t%f\n" % (i + 1, j + 1, word, word_surp))

0 comments on commit 1cbaa13

Please sign in to comment.