Skip to content

Commit

Permalink
Spellchecker: slightly cleaned up the code
Browse files Browse the repository at this point in the history
  • Loading branch information
alldroll committed Jul 28, 2020
1 parent 1eb7542 commit 7a42b21
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 39 deletions.
8 changes: 8 additions & 0 deletions pkg/spellchecker/collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ type lmCollector struct {
scorer suggest.Scorer
}

// newCollectorManager returns a suggest.CollectorManager backed by an
// lmCollectorManager that is configured with the given scorer and topK.
func newCollectorManager(scorer suggest.Scorer, topK int) suggest.CollectorManager {
	manager := new(lmCollectorManager)
	manager.scorer = scorer
	manager.topK = topK

	return manager
}

// lmCollectorManager implements CollectorManager interface
type lmCollectorManager struct {
topK int
Expand Down
25 changes: 18 additions & 7 deletions pkg/spellchecker/scorer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ import (
"github.com/suggest-go/suggest/pkg/lm"
"github.com/suggest-go/suggest/pkg/merger"
"github.com/suggest-go/suggest/pkg/suggest"
"sort"
)

// lmScorer implements the scorer interface on top of a wrapped
// lm.ScorerNext; its score method forwards the word id to ScoreNext.
type lmScorer struct {
scorer lm.ScorerNext
scorer lm.ScorerNext
}

// Score returns the score of the given position
Expand All @@ -22,9 +21,21 @@ func (s *lmScorer) score(id lm.WordID) float64 {
return s.scorer.ScoreNext(id)
}

// sortCandidates performs sort of the given candidates using lm
func sortCandidates(scorer *lmScorer, candidates []suggest.Candidate) {
sort.SliceStable(candidates, func(i, j int) bool {
return scorer.score(candidates[i].Key) > scorer.score(candidates[j].Key)
})
// dummyScorer is a stateless scorer that returns lm.UnknownWordScore for
// every candidate. newScorer falls back to it when no lm.ScorerNext is given.
type dummyScorer struct{}

// Score ignores the candidate and always yields lm.UnknownWordScore.
func (*dummyScorer) Score(candidate merger.MergeCandidate) float64 {
	return lm.UnknownWordScore
}

// newScorer adapts next into a suggest.Scorer. A non-nil next is wrapped
// into an lmScorer; a nil next produces a dummyScorer instead.
func newScorer(next lm.ScorerNext) suggest.Scorer {
	if next != nil {
		return &lmScorer{scorer: next}
	}

	return &dummyScorer{}
}
68 changes: 36 additions & 32 deletions pkg/spellchecker/spellchecker.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Package spellchecker provides spellcheck functionality
package spellchecker

// TODO add tests!!

import (
"sort"

"github.com/suggest-go/suggest/pkg/analysis"
"github.com/suggest-go/suggest/pkg/dictionary"
"github.com/suggest-go/suggest/pkg/lm"
Expand Down Expand Up @@ -41,12 +45,13 @@ func (s *SpellChecker) Predict(query string, topK int, similarity float64) ([]st
}

word, seq := tokens[len(tokens)-1], tokens[:len(tokens)-1]
collectorManager, err := s.createCollectorManager(seq, topK)
scorerNext, err := scorerNext(s.model, seq)

if err != nil {
return nil, err
}

collectorManager := newCollectorManager(newScorer(scorerNext), topK)
candidates, err := s.index.Autocomplete(word, collectorManager)

if err != nil {
Expand All @@ -56,7 +61,7 @@ func (s *SpellChecker) Predict(query string, topK int, similarity float64) ([]st
if len(candidates) < topK {
config, err := suggest.NewSearchConfig(
word,
topK-len(candidates),
topK,
metric.CosineMetric(),
similarity,
)
Expand All @@ -74,55 +79,54 @@ func (s *SpellChecker) Predict(query string, topK int, similarity float64) ([]st
candidates = merge(candidates, fuzzyCandidates)
}

scorer, ok := collectorManager.scorer.(*lmScorer)

if len(seq) > 0 && ok {
sortCandidates(scorer, candidates)
if scorerNext != nil {
sortCandidates(scorerNext, candidates)
}

result := make([]string, 0, len(candidates))

for _, c := range candidates {
val, err := s.dict.Get(c.Key)

if err != nil {
return nil, err
}

result = append(result, val)
// Trim the merged (and possibly re-sorted) list so that at most topK
// candidates are returned; slicing to topK+1 would keep one extra entry.
if topK < len(candidates) {
	candidates = candidates[:topK]
}

return result, nil
return retrieveValues(s.dict, candidates)
}

// createCollectorManager creates a collector manager for the given sentence
func (s *SpellChecker) createCollectorManager(seq []string, topK int) (*lmCollectorManager, error) {
seqIds, err := lm.MapIntoListOfWordIDs(s.model, seq)
// scorerNext creates lm.ScorerNext for the provided sentence.
//
// seq is first mapped into word ids with lm.MapIntoListOfWordIDs; a mapping
// failure is returned as the error with a nil scorer. model.Next is called
// only when the mapped sequence is non-empty, so for an empty sequence both
// named results keep their zero values and the caller receives (nil, nil).
func scorerNext(model lm.LanguageModel, seq lm.Sentence) (next lm.ScorerNext, err error) {
seqIds, err := lm.MapIntoListOfWordIDs(model, seq)

if err != nil {
return nil, err
}

var scorer suggest.Scorer

if len(seqIds) > 0 {
next, err := s.model.Next(seqIds)
// Plain assignment fills the named results that the bare return yields.
next, err = model.Next(seqIds)
}

// Bare return: hands back the named results next and err.
return
}

// retrieveValues fetches the corresponding values from the dictionary for the provided candidates.
//
// Each candidate's Key is looked up with dict.Get and the values are appended
// in candidate order. The first failed lookup aborts the call: its error is
// returned together with a nil slice.
func retrieveValues(dict dictionary.Dictionary, candidates []suggest.Candidate) ([]string, error) {
// Capacity is reserved up front: one value per candidate.
result := make([]string, 0, len(candidates))

for _, c := range candidates {
val, err := dict.Get(c.Key)

if err != nil {
return nil, err
}

if next != nil {
scorer = &lmScorer{
scorer: next,
}
}
result = append(result, val)
}

return &lmCollectorManager{
topK: topK,
scorer: scorer,
}, nil
return result, nil
}

// sortCandidates reorders candidates in place by descending language-model
// score, where the score of a candidate is scorer.ScoreNext of its Key.
// The sort is stable, so candidates with equal scores keep their relative order.
func sortCandidates(scorer lm.ScorerNext, candidates []suggest.Candidate) {
	byScoreDesc := func(a, b int) bool {
		left := scorer.ScoreNext(candidates[a].Key)
		right := scorer.ScoreNext(candidates[b].Key)

		return left > right
	}

	sort.SliceStable(candidates, byScoreDesc)
}

// merge merges the 2 candidates sets into one without duplication
Expand Down

0 comments on commit 7a42b21

Please sign in to comment.