diff --git a/src/rime/gear/contextual_translation.cc b/src/rime/gear/contextual_translation.cc
new file mode 100644
index 0000000000..6266195ec9
--- /dev/null
+++ b/src/rime/gear/contextual_translation.cc
@@ -0,0 +1,89 @@
+#include <algorithm>
+#include <iterator>
+#include <rime/gear/contextual_translation.h>
+#include <rime/gear/translator_commons.h>
+
+namespace rime {
+
+// Upper bound on how many candidates (already cached plus still queued) one
+// Replenish() call gathers before it stops reading the wrapped translation.
+// Unsigned so that it compares cleanly against the container sizes.
+static constexpr size_t kContextualSearchLimit = 32;
+
+// Moves candidates from the wrapped translation into cache_.
+//
+// Consecutive "phrase"/"table" candidates sharing one end position form a
+// group: each is re-scored by Evaluate() and the group is appended to cache_
+// in descending weight order. Any other candidate first flushes the pending
+// group and is then forwarded untouched, so groups keep their relative order.
+// Returns true if cache_ is non-empty afterwards.
+bool ContextualTranslation::Replenish() {
+  vector<an<Phrase>> queue;
+  size_t end_pos = 0;
+  while (!translation_->exhausted() &&
+         cache_.size() + queue.size() < kContextualSearchLimit) {
+    auto cand = translation_->Peek();
+    DLOG(INFO) << cand->text() << " cache/queue: "
+               << cache_.size() << "/" << queue.size();
+    const auto& type = cand->type();
+    // Only a candidate that really converts to Phrase is re-scored. If the
+    // conversion yields an empty pointer the candidate is forwarded like any
+    // other one instead of being dereferenced inside Evaluate().
+    an<Phrase> phrase;
+    if (type == "phrase" || type == "table") {
+      phrase = As<Phrase>(cand);
+    }
+    if (phrase) {
+      if (end_pos != phrase->end()) {
+        // A different end position opens a new group; flush the old one.
+        end_pos = phrase->end();
+        AppendToCache(queue);
+      }
+      queue.push_back(Evaluate(phrase));
+    } else {
+      AppendToCache(queue);
+      cache_.push_back(cand);
+    }
+    if (!translation_->Next()) {
+      break;
+    }
+  }
+  AppendToCache(queue);
+  return !cache_.empty();
+}
+
+// Re-scores `phrase` against its context. A fresh Sentence is offset to the
+// phrase's start and extended with the phrase's entry, preceding_text_ and
+// grammar_; the sentence's weight then replaces the phrase's weight.
+// Returns the same phrase object that was passed in.
+an<Phrase> ContextualTranslation::Evaluate(an<Phrase> phrase) {
+  auto sentence = New<Sentence>(phrase->language());
+  sentence->Offset(phrase->start());
+  // True when the phrase ends exactly at the end of the input string.
+  const bool is_rear = phrase->end() == input_.length();
+  sentence->Extend(phrase->entry(), phrase->end(), is_rear, preceding_text_,
+                   grammar_);
+  phrase->set_weight(sentence->weight());
+  DLOG(INFO) << "contextual suggestion: " << phrase->text()
+             << " weight: " << phrase->weight();
+  return phrase;
+}
+
+// Orders phrases from the highest weight to the lowest.
+static bool compare_by_weight_desc(const an<Phrase>& a, const an<Phrase>& b) {
+  return a->weight() > b->weight();
+}
+
+// Appends the queued phrases to cache_ in descending weight order, then
+// empties the queue. An empty queue is a no-op.
+void ContextualTranslation::AppendToCache(vector<an<Phrase>>& queue) {
+  if (queue.empty()) return;
+  DLOG(INFO) << "appending to cache " << queue.size() << " candidates.";
+  // stable_sort keeps equally weighted phrases in their incoming order, so
+  // ties are resolved deterministically; plain std::sort gives no such promise.
+  std::stable_sort(queue.begin(), queue.end(), compare_by_weight_desc);
+  std::copy(queue.begin(), queue.end(), std::back_inserter(cache_));
+  queue.clear();
+}
+
+}  // namespace rime
diff --git a/src/rime/gear/contextual_translation.h b/src/rime/gear/contextual_translation.h
new file mode 100644
index
0000000000..e817ea4c23
--- /dev/null
+++ b/src/rime/gear/contextual_translation.h
@@ -0,0 +1,57 @@
+//
+// Copyright RIME Developers
+// Distributed under the BSD License
+//
+
+#ifndef RIME_CONTEXTUAL_TRANSLATION_H_
+#define RIME_CONTEXTUAL_TRANSLATION_H_
+
+#include <utility>
+#include <rime/common.h>
+#include <rime/translation.h>
+
+namespace rime {
+
+class Candidate;
+class Grammar;
+class Phrase;
+
+// A PrefetchTranslation that re-weights the "phrase"/"table" candidates of the
+// wrapped translation with a grammar, taking the preceding text into account,
+// and serves them in descending weight order within each end position.
+class ContextualTranslation : public PrefetchTranslation {
+ public:
+  // translation:    source of the candidates to re-weight.
+  // input:          the input string; a phrase ending at its length is "rear".
+  // preceding_text: text before the input, passed on to the sentence scoring.
+  // grammar:        passed on to the sentence scoring; kept as a raw pointer
+  //                 that this class never deletes.
+  ContextualTranslation(an<Translation> translation,
+                        string input,
+                        string preceding_text,
+                        Grammar* grammar)
+      : PrefetchTranslation(translation),
+        // The strings are taken by value, so move them into the members
+        // instead of copying a second time.
+        input_(std::move(input)),
+        preceding_text_(std::move(preceding_text)),
+        grammar_(grammar) {}
+
+ protected:
+  // Refills the candidate cache; returns true if it holds any candidate.
+  bool Replenish() override;
+
+ private:
+  // Recomputes the weight of `phrase` in context and returns the same phrase.
+  an<Phrase> Evaluate(an<Phrase> phrase);
+  // Appends `queue` to the cache by descending weight and clears the queue.
+  void AppendToCache(vector<an<Phrase>>& queue);
+
+  string input_;
+  string preceding_text_;
+  Grammar* grammar_;
+};
+
+}  // namespace rime
+
+#endif  // RIME_CONTEXTUAL_TRANSLATION_H_
diff --git a/src/rime/gear/poet.h b/src/rime/gear/poet.h
index bbd10c2f47..1b1dd44dfa 100644
--- a/src/rime/gear/poet.h
+++ b/src/rime/gear/poet.h
@@ -11,8 +11,10 @@
 #define RIME_POET_H_
 
 #include <rime/common.h>
+#include <rime/translation.h>
 #include <rime/dict/vocabulary.h>
 #include <rime/gear/translator_commons.h>
+#include <rime/gear/contextual_translation.h>
 
 namespace rime {
 
@@ -39,6 +41,24 @@ class Poet {
                             size_t total_length,
                             const string& preceding_text);
 
+  // Wraps `translation` in a ContextualTranslation built from `input`, the
+  // translator's preceding text and this poet's grammar.
+  // Returns `translation` unchanged when the translator reports contextual
+  // suggestions as disabled or when grammar_ is not set.
+  // TranslatorT must provide contextual_suggestions() and GetPrecedingText().
+  template <class TranslatorT>
+  an<Translation> ContextualWeighted(an<Translation> translation,
+                                     const string& input,
+                                     TranslatorT* translator) {
+    if (!translator->contextual_suggestions() || !grammar_) {
+      return translation;
+    }
+    return New<ContextualTranslation>(translation,
+                                      input,
+                                      translator->GetPrecedingText(),
+                                      grammar_.get());
+  }
+
  private:
   const Language* language_;
   the<Grammar> grammar_;
diff --git a/src/rime/gear/script_translator.cc b/src/rime/gear/script_translator.cc
index ce3cfcae1b..33c36c16a6 100644
--- a/src/rime/gear/script_translator.cc
+++ b/src/rime/gear/script_translator.cc
@@ -194,7 +194,11 @@ an<Translation> ScriptTranslator::Query(const string& input,
                   enable_user_dict ?
user_dict_.get() : NULL)) {
     return nullptr;
   }
-  return New<DistinctTranslation>(result);
+  auto deduped = New<DistinctTranslation>(result);
+  if (contextual_suggestions_) {
+    return poet_->ContextualWeighted(deduped, input, this);
+  }
+  return deduped;
 }
 
 string ScriptTranslator::FormatPreedit(const string& preedit) {
diff --git a/src/rime/gear/table_translator.cc b/src/rime/gear/table_translator.cc
index 24c8ae6a37..e21236e7a4 100644
--- a/src/rime/gear/table_translator.cc
+++ b/src/rime/gear/table_translator.cc
@@ -220,7 +220,8 @@ TableTranslator::TableTranslator(const Ticket& ticket)
                    &max_phrase_length_);
     config->GetInt(name_space_ + "/max_homographs",
                    &max_homographs_);
-    if (enable_sentence_ || sentence_over_completion_) {
+    if (enable_sentence_ || sentence_over_completion_ ||
+        contextual_suggestions_) {
       poet_.reset(new Poet(language(), config, Poet::LeftAssociateCompare));
     }
   }
@@ -306,11 +307,15 @@ an<Translation> TableTranslator::Query(const string& input,
       translation = sentence + translation;
     }
   }
-  if (translation) {
-    translation = New<DistinctTranslation>(translation);
-  }
-  if (translation && translation->exhausted()) {
-    translation.reset();  // discard futile translation
+  // Nothing to offer: either no translation was produced at all or it is
+  // already exhausted. Checking for null here keeps the wrappers below from
+  // being built around an empty pointer.
+  if (!translation || translation->exhausted()) {
+    return nullptr;  // discard futile translation
+  }
+  translation = New<DistinctTranslation>(translation);
+  if (contextual_suggestions_) {
+    return poet_->ContextualWeighted(translation, input, this);
   }
   return translation;
 }
diff --git a/src/rime/gear/translator_commons.h b/src/rime/gear/translator_commons.h
index 2f2efff5c0..343fd74859 100644
--- a/src/rime/gear/translator_commons.h
+++ b/src/rime/gear/translator_commons.h
@@ -91,8 +91,9 @@ class Phrase : public Candidate {
   void set_syllabifier(an<PhraseSyllabifier> syllabifier) {
     syllabifier_ = syllabifier;
   }
-
   double weight() const { return entry_->weight; }
+  // Overwrites the weight stored in the underlying dictionary entry.
+  void set_weight(double weight) { entry_->weight = weight; }
   Code& code() const { return entry_->code; }
   const DictEntry& entry() const { return *entry_; }
   const Language* language() const { return language_; }