forked from Dechrissen/LIN538-Final
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrigram_model.py
99 lines (86 loc) · 3.12 KB
/
trigram_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import math
import os
import random
from collections import Counter, defaultdict
from pathlib import Path

import nltk
from nltk import trigrams
# Directory containing the Wall Street Journal (WSJ) corpus files.
# NOTE: hard-coded, machine-specific absolute path; adjust it to the local
# copy of the corpus before running.
corpus_path = Path("C:/Users/Derek/Documents/wsj_corpus")
def trigram_model(corpus_path, train_size=1964):
    """Build a trigram language model from the training part of a corpus.

    Every file among the first ``train_size`` entries of
    ``os.listdir(corpus_path)`` is read line by line.  Lines containing
    ``.START`` and lines that consist only of a newline are skipped; every
    other line is split on whitespace and treated as one sentence.

    Args:
        corpus_path: Directory holding the corpus files.  It must support
            the ``/`` operator with a file name (e.g. a ``pathlib.Path``).
        train_size: Number of directory entries used for training.  The
            default of 1964 is the value that used to be hard-coded here.

    Returns:
        A two-level ``defaultdict`` in which ``model[(w1, w2)][w3]`` is the
        relative frequency of ``w3`` after the context ``(w1, w2)``.
        Sentence boundaries are padded, so contexts and continuations may
        contain ``None``.  Looking up a continuation that was never seen in
        training yields (and stores) the default value 0.01.
    """
    # Unseen continuations default to 0.01 so that lookups on test data never
    # return zero.  As a consequence every seen continuation also starts from
    # 0.01 (its value is 0.01 + count) before normalisation.
    model = defaultdict(lambda: defaultdict(lambda: 0.01))

    # NOTE: os.listdir() returns entries in arbitrary order, so which files
    # fall inside the training slice depends on the platform's listing order.
    for file_name in os.listdir(corpus_path)[:train_size]:
        with open(corpus_path / file_name, 'r') as corpus_file:
            # Stream the file instead of loading all of its lines up front.
            for line in corpus_file:
                if '.START' in line or line == '\n':
                    continue
                tokens = line.split()
                for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
                    model[(w1, w2)][w3] += 1

    # Turn the smoothed counts of each context into relative frequencies.
    for continuations in model.values():
        total_count = float(sum(continuations.values()))
        for w3 in continuations:
            continuations[w3] /= total_count

    return model
def generate_sentence(model):
    """Generate, print and return a random sentence from a trigram model.

    Starting from the padded context ``(None, None)``, the next token is
    drawn at random in proportion to ``model[(w1, w2)][token]``.  Generation
    stops once the last two tokens are both ``None`` (the end-of-sentence
    padding) or when the current context has no known continuation.

    Args:
        model: Mapping from ``(w1, w2)`` context tuples to a mapping of
            ``token -> weight``.  The weights do not need to sum to one.

    Returns:
        The generated sentence as a single space-joined string.  The same
        string is also printed to standard output.
    """
    text = [None, None]
    while True:
        # .get() (rather than indexing) keeps a defaultdict model from
        # growing an empty entry for every context that is probed.
        continuations = model.get(tuple(text[-2:]))
        if not continuations:
            # Unknown context: end the sentence instead of retrying forever.
            break

        candidates = list(continuations)
        weights = list(continuations.values())
        # random.choices always returns a token, so floating-point rounding
        # in the weights can no longer leave a step without a selection
        # (which used to be mistaken for the end of the sentence).
        text.append(random.choices(candidates, weights=weights)[0])

        if text[-2:] == [None, None]:
            break

    # Drop the padding (and any other falsy token) before joining.
    sentence = ' '.join([t for t in text if t])
    print(sentence)
    return sentence
def perplexity(test_sent, model):
    """Compute the perplexity of a trigram model on one test sentence.

    The sentence is split on whitespace, padded on both sides and scored
    trigram by trigram.  The result is the geometric mean of the inverse
    trigram probabilities, i.e. ``exp(-(1/N) * sum(log p_i))``.

    The sum is accumulated in log space: multiplying the raw inverse
    probabilities overflows to infinity on long sentences even though the
    final N-th root would be a perfectly ordinary number.

    Args:
        test_sent: The sentence as a single whitespace-separated string.
        model: Trigram model indexed as ``model[(w1, w2)][w3]``.  With the
            nested ``defaultdict`` built by ``trigram_model`` an unseen
            trigram is looked up (and stored) with its default value.

    Returns:
        The perplexity as a float.  ``float("inf")`` is returned when a
        trigram has a non-positive probability, when no trigram could be
        scored, or when the value is too large to be represented.
    """
    tokens = test_sent.split()
    log_prob_sum = 0.0
    n_trigrams = 0
    for w1, w2, w3 in trigrams(tokens, pad_right=True, pad_left=True):
        prob = model[(w1, w2)][w3]
        if prob <= 0:
            # A zero-probability trigram makes the sentence impossible.
            return float("inf")
        log_prob_sum += math.log(prob)
        n_trigrams += 1

    if n_trigrams == 0:
        # Nothing was scored, so the perplexity is undefined.
        return float("inf")

    try:
        return math.exp(-log_prob_sum / n_trigrams)
    except OverflowError:
        return float("inf")
# Create a trigram model according to wsj corpus
model = trigram_model(corpus_path)

# Construct a test set of 20% of the Wall Street Journal corpus
# (directory entries 1964 - 2454).
# NOTE: this relies on os.listdir() returning the entries in the same order
# as it did inside trigram_model(), otherwise the training and test slices
# could overlap.
testset = []
for file in os.listdir(corpus_path)[1964:2455]:
    with open(corpus_path / file, 'r') as current:
        for sentence in current:
            # Skip the '.START' marker lines and empty lines.
            if ('.START' in sentence) or (sentence == '\n'):
                continue
            testset.append(sentence)

# Calculate the perplexity of the model on every test sentence,
# ignoring infinity cases
perplexities = []
for sentence in testset:
    p = perplexity(sentence, model)
    if p == float("inf"):
        continue
    perplexities.append(p)

# Average of the per-sentence perplexities.  With no finite value at all
# (e.g. an empty test set) the average is undefined, so report NaN instead of
# failing with a ZeroDivisionError.
if perplexities:
    PP = sum(perplexities) / len(perplexities)
else:
    PP = float("nan")

# --- Output ---

# Print the probability of a sentence starting with 'The'
print(model[None, None]["The"])

# Generate a sentence according to the model
generate_sentence(model)

# Print the model's perplexity on our test set
print('Model perplexity on test set:', PP)