Skip to content

Commit

Permalink
Merge pull request suggest-go#38 from suggest-go/scorer_next
Browse files Browse the repository at this point in the history
Introduced a new entity ScorerNext
  • Loading branch information
alldroll authored Dec 21, 2019
2 parents 571a36c + 4335bc5 commit fc5e1ea
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 104 deletions.
4 changes: 2 additions & 2 deletions pkg/lm/language_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type LanguageModel interface {
// GetWordID returns id for the given token
GetWordID(token Token) (WordID, error)
// Next returns the list of candidates for the given sequence
Next(sequence []WordID) ([]WordID, error)
Next(sequence []WordID) (ScorerNext, error)
}

// languageModel implements LanguageModel interface
Expand Down Expand Up @@ -97,7 +97,7 @@ func (lm *languageModel) GetWordID(token Token) (WordID, error) {
}

// Next returns the list of next candidates for the given sequence
func (lm *languageModel) Next(sequence []WordID) ([]WordID, error) {
func (lm *languageModel) Next(sequence []WordID) (ScorerNext, error) {
nGramOrder := int(lm.config.NGramOrder)

if len(sequence)+1 < nGramOrder {
Expand Down
23 changes: 18 additions & 5 deletions pkg/lm/ngram_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type NGramModel interface {
// Score returns a lm value of the given sequence of WordID
Score(nGrams []WordID) float64
// Next returns a list of WordID which follow after the given sequence of nGrams
Next(nGrams []WordID) ([]WordID, error)
Next(nGrams []WordID) (ScorerNext, error)
}

const (
Expand Down Expand Up @@ -61,12 +61,13 @@ func (m *nGramModel) Score(nGrams []WordID) float64 {
}

// Next returns a list of WordID where each candidate follows after the given sequence of nGrams
func (m *nGramModel) Next(nGrams []WordID) ([]WordID, error) {
if int(m.nGramOrder) <= len(nGrams) {
func (m *nGramModel) Next(nGrams []WordID) (ScorerNext, error) {
if int(m.nGramOrder) <= len(nGrams) || len(nGrams) == 0 {
return nil, errors.New("nGrams length should be less than the nGramModel order")
}

order := 0
counts := make([]WordCount, 0, len(nGrams))
count := WordCount(0)
parent := InvalidContextOffset

Expand All @@ -75,11 +76,23 @@ func (m *nGramModel) Next(nGrams []WordID) ([]WordID, error) {
count, parent = vector.GetCount(nGrams[order], parent)

if count == 0 {
return []WordID{}, nil
return nil, nil
}

counts = append(counts, count)
}

subVector := m.indices[order].SubVector(parent)

if subVector == nil {
return nil, nil
}

return m.indices[order].Next(parent), nil
return &scorerNext{
contextCounts: counts,
nGramVector: subVector,
context: parent,
}, nil
}

// MarshalBinary encodes the receiver into a binary form and returns the result.
Expand Down
109 changes: 64 additions & 45 deletions pkg/lm/ngram_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package lm
import (
"bytes"
"encoding/gob"
"fmt"
"github.com/suggest-go/suggest/pkg/store"
"math"
"reflect"
"testing"
)

Expand Down Expand Up @@ -35,65 +35,84 @@ func TestScoreFromFile(t *testing.T) {
}

func TestPredict(t *testing.T) {
indexer, err := buildIndexerWithInMemoryDictionary("testdata/fixtures/1-gm")

if err != nil {
t.Errorf("Unexpected error: %v", err)
cases := []struct {
nGrams Sentence
word string
expected float64
}{
{
nGrams: Sentence{"i", "am"},
word: "sam",
expected: -0.6931,
},
{
nGrams: Sentence{"i", "am"},
word: "</S>",
expected: -0.6931,
},
{
nGrams: Sentence{"i"},
word: "am",
expected: -0.4054,
},
{
nGrams: Sentence{"i"},
word: "do",
expected: -1.0986,
},
{
nGrams: Sentence{"green"},
word: "eggs",
expected: 0.0,
},
}

directory, err := store.NewFSDirectory("testdata/fixtures")
for i, c := range cases {
t.Run(fmt.Sprintf("predict #%d", i), func(t *testing.T) {
indexer, err := buildIndexerWithInMemoryDictionary("testdata/fixtures/1-gm")

if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

reader := NewGoogleNGramReader(3, indexer, directory)
ids := make([]WordID, 0, 3)
directory, err := store.NewFSDirectory("testdata/fixtures")

model, err := reader.Read()
if err != nil {
t.Errorf("Unexpected error %v", err)
}
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

cases := []struct {
nGrams Sentence
expected []Token
}{
{Sentence{"i", "am"}, []Token{"sam", "</S>"}},
{Sentence{"i"}, []Token{"am", "do"}},
{Sentence{"green"}, []Token{"eggs"}},
}
reader := NewGoogleNGramReader(3, indexer, directory)
ids := make([]WordID, 0, 3)

for _, c := range cases {
for _, nGram := range c.nGrams {
id, _ := indexer.Get(nGram)
ids = append(ids, id)
}
model, err := reader.Read()

list, err := model.Next(ids)
if err != nil {
t.Errorf("Unexpected error %v", err)
}
if err != nil {
t.Errorf("Unexpected error %v", err)
}

ids = ids[:0]
actual := []Token{}
for _, nGram := range c.nGrams {
id, _ := indexer.Get(nGram)
ids = append(ids, id)
}

scorerNext, err := model.Next(ids)

for _, item := range list {
token, err := indexer.Find(item)
if err != nil {
t.Errorf("Unexpected error %v", err)
}

actual = append(actual, token)
}
id, _ := indexer.Get(c.word)
actual := scorerNext.ScoreNext(id)

if !reflect.DeepEqual(actual, c.expected) {
t.Errorf(
"Test fail, expected %v, got %v",
c.expected,
actual,
)
}
if diff := math.Abs(c.expected - actual); diff >= tolerance {
t.Errorf(
"Test fail, for %v expected score %v, got %v",
c.nGrams,
c.expected,
actual,
)
}
})
}
}

Expand Down
26 changes: 16 additions & 10 deletions pkg/lm/ngram_vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ type NGramVector interface {
GetContextOffset(word WordID, context ContextOffset) ContextOffset
// CorpusCount returns size of all counts in the collection
CorpusCount() WordCount
// Next returns next words for the given context
Next(context ContextOffset) []WordID
// SubVector returns NGramVector for the given context
SubVector(context ContextOffset) NGramVector
}

const (
Expand Down Expand Up @@ -65,6 +65,7 @@ func (s *sortedArray) GetCount(word WordID, context ContextOffset) (WordCount, C
// GetContextOffset returns the given node context offset
func (s *sortedArray) GetContextOffset(word WordID, context ContextOffset) ContextOffset {
key := makeKey(word, context)

return s.find(key)
}

Expand All @@ -73,23 +74,24 @@ func (s *sortedArray) CorpusCount() WordCount {
return s.total
}

// Next returns next words for the given context
func (s *sortedArray) Next(context ContextOffset) []WordID {
// SubVector returns NGramVector for the given context
func (s *sortedArray) SubVector(context ContextOffset) NGramVector {
minChild := makeKey(0, context)
maxChild := makeKey(maxContextOffset-2, context)

i := sort.Search(len(s.keys), func(i int) bool { return s.keys[i] >= minChild })
var words []WordID

if i < 0 || i >= len(s.keys) {
return words
return nil
}

for ; s.keys[i] <= maxChild; i++ {
words = append(words, getWordID(s.keys[i]))
}
j := sort.Search(len(s.keys)-i, func(j int) bool { return s.keys[j+i] >= maxChild })

return words
return &sortedArray{
keys: s.keys[i:i+j],
values: s.values[i:i+j],
total: s.total,
}
}

// MarshalBinary encodes the receiver into a binary form and returns the result.
Expand Down Expand Up @@ -163,6 +165,10 @@ func (s *sortedArray) UnmarshalBinary(data []byte) error {

// find finds the given key in the collection. Returns ContextOffset if the key exists, otherwise returns InvalidContextOffset
func (s *sortedArray) find(key uint64) ContextOffset {
if len(s.keys) == 0 || s.keys[0] > key || s.keys[len(s.keys) - 1] < key {
return InvalidContextOffset
}

i := sort.Search(len(s.keys), func(i int) bool { return s.keys[i] >= key })

if i < 0 || i >= len(s.keys) || s.keys[i] != key {
Expand Down
23 changes: 23 additions & 0 deletions pkg/lm/scorer_next.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package lm

// ScorerNext represents the entity that is responsible for scoring a word
// using the parent context it was built from.
type ScorerNext interface {
	// ScoreNext calculates the score for the given nGram built on the parent context.
	ScoreNext(nGram WordID) float64
}

// scorerNext is the default ScorerNext implementation produced by
// nGramModel.Next.
type scorerNext struct {
	// contextCounts holds the count collected for each nGram of the
	// parent context path (appended per order in nGramModel.Next).
	contextCounts []WordCount
	// context is the offset of the parent context in the nGram vector.
	context ContextOffset
	// nGramVector is the sub-vector restricted to the parent context
	// (see sortedArray.SubVector).
	nGramVector NGramVector
}

// ScoreNext calculates the score for the given nGram built on the parent context.
// It returns UnknownWordScore when the nGram does not occur in the context.
func (s *scorerNext) ScoreNext(nGram WordID) float64 {
	count, _ := s.nGramVector.GetCount(nGram, s.context)

	if count == 0 {
		return UnknownWordScore
	}

	// Copy into a fresh slice instead of appending to s.contextCounts:
	// contextCounts may have spare capacity (it is created with
	// make([]WordCount, 0, n) by the caller), so a bare append would write
	// into its shared backing array — a data race if ScoreNext is called
	// concurrently on the same scorer.
	counts := make([]WordCount, 0, len(s.contextCounts)+1)
	counts = append(counts, s.contextCounts...)
	counts = append(counts, count)

	return calcScore(counts)
}
33 changes: 4 additions & 29 deletions pkg/spellchecker/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,25 @@ import (
"github.com/suggest-go/suggest/pkg/lm"
"github.com/suggest-go/suggest/pkg/merger"
"github.com/suggest-go/suggest/pkg/suggest"
"sort"
)

// lmCollector implements Collector interface
type lmCollector struct {
topKQueue suggest.TopKQueue
scorer suggest.Scorer
next []lm.WordCount
}

// lmCollectorManager implements CollectorManager interface
type lmCollectorManager struct {
topK int
scorer *lmScorer
next []lm.WordCount
scorer suggest.Scorer
}

// Create creates a new collector that will be used for a search segment
func (l *lmCollectorManager) Create() (suggest.Collector, error) {
return &lmCollector{
topKQueue: suggest.NewTopKQueue(l.topK),
scorer: l.scorer,
next: l.next,
}, nil
}

Expand All @@ -47,7 +43,7 @@ func (l *lmCollectorManager) Reduce(collectors []suggest.Collector) []suggest.Ca
func (c *lmCollector) Collect(item merger.MergeCandidate) error {
doc := item.Position()

if len(c.next) == 0 {
if c.scorer == nil {
if c.topKQueue.IsFull() {
return merger.ErrCollectionTerminated
}
Expand All @@ -57,29 +53,8 @@ func (c *lmCollector) Collect(item merger.MergeCandidate) error {
return nil
}

if c.next[0] > item.Position() {
c.topKQueue.Add(doc, lm.UnknownWordScore)

return nil
}

if c.next[len(c.next)-1] < item.Position() {
c.topKQueue.Add(doc, lm.UnknownWordScore)
c.next = c.next[:0]

return nil
}

i := sort.Search(len(c.next), func(i int) bool { return c.next[i] >= doc })
c.next = c.next[i:]

if c.next[0] != doc {
c.topKQueue.Add(doc, lm.UnknownWordScore)

return nil
}

c.topKQueue.Add(doc, c.scorer.Score(item))
score := c.scorer.Score(item)
c.topKQueue.Add(doc, score)

return nil
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/spellchecker/scorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import (

// lmScorer implements the scorer interface
type lmScorer struct {
model lm.LanguageModel
sentence []lm.WordID
scorer lm.ScorerNext
}

// Score returns the score of the given position
Expand All @@ -20,7 +19,7 @@ func (s *lmScorer) Score(candidate merger.MergeCandidate) float64 {

// score returns the lm score for the given word ID
func (s *lmScorer) score(id lm.WordID) float64 {
return s.model.ScoreWordIDs(append(s.sentence, id))
return s.scorer.ScoreNext(id)
}

// sortCandidates performs sort of the given candidates using lm
Expand Down
Loading

0 comments on commit fc5e1ea

Please sign in to comment.