From f2f991096bb1f59f46d40d7ec45f54126484a19f Mon Sep 17 00:00:00 2001 From: Michael Ringgaard Date: Sat, 9 Nov 2019 11:31:12 +0100 Subject: [PATCH] Write parser model after training (#425) --- python/task/wiki.py | 2 +- sling/nlp/document/affix.cc | 2 +- sling/nlp/document/document.cc | 35 ---------- sling/nlp/document/document.h | 67 ------------------ sling/nlp/document/lex.cc | 7 +- sling/nlp/document/lexical-encoder.cc | 2 +- sling/nlp/parser/BUILD | 1 + sling/nlp/parser/action-table.cc | 97 +++++++++++++------------- sling/nlp/parser/action-table.h | 3 + sling/nlp/parser/cascade.cc | 1 - sling/nlp/parser/caspar-trainer.cc | 61 ++++++++++++++++ sling/nlp/parser/parser-trainer.cc | 77 +++++++++++++++++--- sling/nlp/parser/parser-trainer.h | 16 ++++- sling/nlp/parser/tools/train.sh | 2 +- sling/nlp/parser/tools/train_caspar.py | 12 ++++ sling/task/learner.cc | 1 + 16 files changed, 216 insertions(+), 170 deletions(-) diff --git a/python/task/wiki.py b/python/task/wiki.py index 2e97bb62..d81e468a 100644 --- a/python/task/wiki.py +++ b/python/task/wiki.py @@ -116,7 +116,7 @@ def wikidata_import(self, input, name=None): def wikidata(self, dump=None): """Import Wikidata dump to frame format. It takes a Wikidata dump in JSON - format as inpput and converts each item and property to a SLING frame. + format as input and converts each item and property to a SLING frame. Returns the item and property output files.""" if dump == None: dump = self.wikidata_dump() with self.wf.namespace("wikidata"): diff --git a/sling/nlp/document/affix.cc b/sling/nlp/document/affix.cc index dd3333a6..78f08252 100644 --- a/sling/nlp/document/affix.cc +++ b/sling/nlp/document/affix.cc @@ -122,7 +122,7 @@ void AffixTable::Write(OutputStream *stream) const { output.WriteVarint32(affix->form().size()); output.Write(affix->form()); output.WriteVarint32(affix->length()); - if (affix->length() > 1) { + if (affix->length() > 0) { CHECK(affix->shorter() != nullptr); output.WriteVarint32(affix->shorter()->id()); } diff --git a/sling/nlp/document/document.cc b/sling/nlp/document/document.cc index 6ed624ce..8b4794aa 100644 --- a/sling/nlp/document/document.cc +++ b/sling/nlp/document/document.cc @@ -46,12 +46,10 @@ CaseForm Token::Form() const { void Span::Evoke(const Frame &frame) { mention_.Add(document_->names_->n_evokes, frame); - document_->AddMention(frame.handle(), this); } void Span::Evoke(Handle frame) { mention_.Add(document_->names_->n_evokes, frame); - document_->AddMention(frame, this); } void Span::Replace(Handle existing, Handle replacement) { @@ -59,9 +57,7 @@ void Span::Replace(Handle existing, Handle replacement) { FrameDatum *mention = mention_.store()->GetFrame(mention_.handle()); for (Slot *slot = mention->begin(); slot < mention->end(); ++slot) { if (slot->name == n_evokes && slot->value == existing) { - document_->RemoveMention(existing, this); slot->value = replacement; - document_->AddMention(replacement, this); return; } } @@ -268,9 +264,6 @@ Document::Document(const Frame &top, const DocumentNames *names) Span *span = Insert(begin, end); CHECK(span != nullptr) << "Crossing span: " << begin << "," << end; span->mention_ = Frame(store(), mention->self); - for (const Slot &s : span->mention_) { - if (s.name == names_->n_evokes) AddMention(s.value, span); - } } else if (slot->name == names_->n_theme.handle()) { // Add thematic frame. themes_.push_back(slot->value); @@ -305,9 +298,6 @@ Document::Document(const Document &other, bool annotations) for (const Span *s : other.spans_) { Span *span = Insert(s->begin_, s->end_); span->mention_ = Frame(store, store->Clone(s->mention_.handle())); - for (const Slot &s : span->mention_) { - if (s.name == names_->n_evokes) AddMention(s.value, span); - } } // Copy themes. @@ -366,9 +356,6 @@ Document::Document(const Document &other, if (b < 0 || e > length) continue; Span *span = Insert(b, e); span->mention_ = Frame(store, store->Clone(s->mention_.handle())); - for (const Slot &s : span->mention_) { - if (s.name == names_->n_evokes) AddMention(s.value, span); - } } } @@ -511,13 +498,6 @@ void Document::DeleteSpan(Span *span) { // Remove span from span index. Remove(span); - // Remove all evoked frames from mention table. - for (const Slot &slot : span->mention_) { - if (slot.name == names_->n_evokes) { - RemoveMention(slot.value, span); - } - } - // Clear the reference to the mention frame. This will mark the span as // deleted. span->mention_ = Frame::nil(); @@ -537,20 +517,6 @@ void Document::AddExtra(Handle name, Handle value) { extras_->emplace_back(name, value); } -void Document::AddMention(Handle handle, Span *span) { - mentions_.emplace(handle, span); -} - -void Document::RemoveMention(Handle handle, Span *span) { - auto interval = mentions_.equal_range(handle); - for (auto it = interval.first; it != interval.second; ++it) { - if (it->second == span) { - mentions_.erase(it); - break; - } - } -} - int Document::Locate(int position) const { int index = 0; int len = tokens_.size(); @@ -781,7 +747,6 @@ void Document::ClearAnnotations() { for (Token &t : tokens_) t.span_ = nullptr; for (Span *s : spans_) delete s; spans_.clear(); - mentions_.clear(); themes_.clear(); } diff --git a/sling/nlp/document/document.h b/sling/nlp/document/document.h index 8876e866..36d0f512 100644 --- a/sling/nlp/document/document.h +++ b/sling/nlp/document/document.h @@ -374,62 +374,6 @@ class Document { AddExtra(name.handle(), store()->AllocateString(value)); } - // Types for mapping from frame to spans that evoke it. - typedef std::unordered_multimap MentionMap; - typedef std::pair - ConstMentionIteratorPair; - typedef std::pair - MentionIteratorPair; - - // Iterator adapters for mention ranges. - class ConstMentionRange { - public: - explicit ConstMentionRange(const ConstMentionIteratorPair &interval) - : interval_(interval) {} - MentionMap::const_iterator begin() const { return interval_.first; } - MentionMap::const_iterator end() const { return interval_.second; } - - private: - ConstMentionIteratorPair interval_; - }; - - class MentionRange { - public: - explicit MentionRange(const MentionIteratorPair &interval) - : interval_(interval) {} - MentionMap::iterator begin() { return interval_.first; } - MentionMap::iterator end() { return interval_.second; } - - private: - MentionIteratorPair interval_; - }; - - // Iterates over all spans that evoke a frame, e.g.: - // for (const auto &it : document.EvokingSpans(h)) { - // Span *s = it.second; - // } - ConstMentionRange EvokingSpans(Handle handle) const { - return ConstMentionRange(mentions_.equal_range(handle)); - } - ConstMentionRange EvokingSpans(const Frame &frame) const { - return ConstMentionRange(mentions_.equal_range(frame.handle())); - } - - MentionRange EvokingSpans(Handle handle) { - return MentionRange(mentions_.equal_range(handle)); - } - MentionRange EvokingSpans(const Frame &frame) { - return MentionRange(mentions_.equal_range(frame.handle())); - } - - // Returns the number of spans evoking a frame. - int EvokingSpanCount(Handle handle) { - return mentions_.count(handle); - } - int EvokingSpanCount(const Frame &frame) { - return mentions_.count(frame.handle()); - } - // Clears annotations (mentions and themes) from document. void ClearAnnotations(); @@ -445,12 +389,6 @@ class Document { // Removes the span from the span index. void Remove(Span *span); - // Adds frame to mention mapping. - void AddMention(Handle handle, Span *span); - - // Removes frame from mention mapping. - void RemoveMention(Handle handle, Span *span); - // Document frame. Frame top_; @@ -474,11 +412,6 @@ class Document { // Additional slots that should be added to document. Slots *extras_ = nullptr; - // Inverse mapping from frames to spans that can be used for looking up all - // mentions of a frame. The handles are tracked by the mention frame in the - // span. - MentionMap mentions_; - // Document symbol names. const DocumentNames *names_; diff --git a/sling/nlp/document/lex.cc b/sling/nlp/document/lex.cc index 2ca7d1ca..7655ff1a 100644 --- a/sling/nlp/document/lex.cc +++ b/sling/nlp/document/lex.cc @@ -133,6 +133,7 @@ bool DocumentLexer::Lex(Document *document, Text lex) const { if (objects.size() != current_object + 1) return false; // Add mentions to document. + HandleSet added; for (auto &m : markables) { int begin = document->Locate(m.begin); int end = document->Locate(m.end); @@ -140,7 +141,9 @@ bool DocumentLexer::Lex(Document *document, Text lex) const { if (m.object != -1) { Array evoked(store, objects[m.object]); for (int i = 0; i < evoked.length(); ++i) { - span->Evoke(evoked.get(i)); + Handle frame = evoked.get(i); + span->Evoke(frame); + added.insert(frame); } } } @@ -148,7 +151,7 @@ bool DocumentLexer::Lex(Document *document, Text lex) const { // Add thematic frames. Do not add frames that are evoked by spans. for (int theme : themes) { Handle frame = objects[theme]; - if (document->EvokingSpanCount(frame) == 0) { + if (added.count(frame) == 0) { document->AddTheme(frame); } } diff --git a/sling/nlp/document/lexical-encoder.cc b/sling/nlp/document/lexical-encoder.cc index 1e5c8573..e7a54451 100644 --- a/sling/nlp/document/lexical-encoder.cc +++ b/sling/nlp/document/lexical-encoder.cc @@ -54,7 +54,7 @@ void LexicalFeatures::LoadLexicon(Flow *flow) { void LexicalFeatures::SaveLexicon(myelin::Flow *flow) const { // Save word vocabulary. Flow::Blob *vocabulary = flow->AddBlob("lexicon", "dict"); - vocabulary->SetAttr("delimiter", 10); + vocabulary->SetAttr("delimiter", 0); vocabulary->SetAttr("oov", lexicon_.oov()); auto normalization = lexicon_.normalization(); vocabulary->SetAttr("normalization", NormalizationString(normalization)); diff --git a/sling/nlp/parser/BUILD b/sling/nlp/parser/BUILD index b9df1877..73a4c8bf 100644 --- a/sling/nlp/parser/BUILD +++ b/sling/nlp/parser/BUILD @@ -139,6 +139,7 @@ cc_library( ":roles", "//sling/base", "//sling/frame:store", + "//sling/frame:serialization", "//sling/myelin:builder", "//sling/myelin:compiler", "//sling/myelin:gradient", diff --git a/sling/nlp/parser/action-table.cc b/sling/nlp/parser/action-table.cc index e12b28e2..31b1725c 100644 --- a/sling/nlp/parser/action-table.cc +++ b/sling/nlp/parser/action-table.cc @@ -38,37 +38,40 @@ void ActionTable::Init(Store *store) { CHECK(top.valid()); // Get all the integer fields. - max_actions_per_token_ = top.GetInt("/table/max_actions_per_token"); - frame_limit_ = top.GetInt("/table/frame_limit"); + max_actions_per_token_ = top.GetInt("/table/max_actions_per_token", 5); + frame_limit_ = top.GetInt("/table/frame_limit", 5); // Read the action index. Array actions = top.Get("/table/actions").AsArray(); CHECK(actions.valid()); - Handle action_type = store->LookupExisting("/table/action/type"); - Handle action_length = store->LookupExisting("/table/action/length"); - Handle action_source = store->LookupExisting("/table/action/source"); - Handle action_target = store->LookupExisting("/table/action/target"); - Handle action_role = store->LookupExisting("/table/action/role"); - Handle action_label = store->LookupExisting("/table/action/label"); + Handle n_type = store->Lookup("/table/action/type"); + Handle n_length = store->Lookup("/table/action/length"); + Handle n_source = store->Lookup("/table/action/source"); + Handle n_target = store->Lookup("/table/action/target"); + Handle n_role = store->Lookup("/table/action/role"); + Handle n_label = store->Lookup("/table/action/label"); + Handle n_delegate = store->Lookup("/table/action/delegate"); for (int i = 0; i < actions.length(); ++i) { ParserAction action; Frame item(store, actions.get(i)); CHECK(item.valid()); for (const Slot &slot : item) { - if (slot.name == action_type) { + if (slot.name == n_type) { action.type = static_cast(slot.value.AsInt()); - } else if (slot.name == action_length) { + } else if (slot.name == n_length) { action.length = slot.value.AsInt(); - } else if (slot.name == action_source) { + } else if (slot.name == n_source) { action.source = slot.value.AsInt(); - } else if (slot.name == action_target) { + } else if (slot.name == n_target) { action.target = slot.value.AsInt(); - } else if (slot.name == action_role) { + } else if (slot.name == n_role) { action.role = slot.value; - } else if (slot.name == action_label) { + } else if (slot.name == n_label) { action.label = slot.value; + } else if (slot.name == n_delegate) { + action.delegate = slot.value.AsInt(); } } @@ -82,66 +85,62 @@ void ActionTable::Save(const Store *global, const string &file) const { } string ActionTable::Serialize(const Store *global) const { + // Build frame with action table. Store store(global); - Builder top(&store); - top.AddId("/table"); + Builder table(&store); + table.AddId("/table"); + Write(&table); + StringEncoder encoder(&store); + encoder.Encode(table.Create()); + return encoder.buffer(); +} + +void ActionTable::Write(Builder *frame) const { // Save the action table. - Handle action_type = store.Lookup("/table/action/type"); - Handle action_length = store.Lookup("/table/action/length"); - Handle action_source = store.Lookup("/table/action/source"); - Handle action_target = store.Lookup("/table/action/target"); - Handle action_role = store.Lookup("/table/action/role"); - Handle action_label = store.Lookup("/table/action/label"); - - Array actions(&store, actions_.size()); + Store *store = frame->store(); + Handle n_type = store->Lookup("/table/action/type"); + Handle n_length = store->Lookup("/table/action/length"); + Handle n_source = store->Lookup("/table/action/source"); + Handle n_target = store->Lookup("/table/action/target"); + Handle n_role = store->Lookup("/table/action/role"); + Handle n_label = store->Lookup("/table/action/label"); + Handle n_delegate = store->Lookup("/table/action/delegate"); + + Array actions(store, actions_.size()); int index = 0; for (const ParserAction &action : actions_) { auto type = action.type; - Builder b(&store); - b.Add(action_type, static_cast(type)); + Builder b(store); + b.Add(n_type, static_cast(type)); if (type == ParserAction::REFER || type == ParserAction::EVOKE) { if (action.length > 0) { - b.Add(action_length, static_cast(action.length)); + b.Add(n_length, static_cast(action.length)); } } if (type == ParserAction::ASSIGN || type == ParserAction::ELABORATE || type == ParserAction::CONNECT) { if (action.source != 0) { - b.Add(action_source, static_cast(action.source)); + b.Add(n_source, static_cast(action.source)); } } if (type == ParserAction::EMBED || type == ParserAction::REFER || type == ParserAction::CONNECT) { if (action.target != 0) { - b.Add(action_target, static_cast(action.target)); + b.Add(n_target, static_cast(action.target)); } } - if (!action.role.IsNil()) b.Add(action_role, action.role); - if (!action.label.IsNil()) b.Add(action_label, action.label); + if (type == ParserAction::CASCADE) { + b.Add(n_delegate, static_cast(action.delegate)); + } + if (!action.role.IsNil()) b.Add(n_role, action.role); + if (!action.label.IsNil()) b.Add(n_label, action.label); actions.set(index++, b.Create().handle()); } - top.Add("/table/actions", actions); - - // Add artificial links to symbols used in serialization. This is needed as - // some action types might be unseen, so their corresponding symbols won't be - // serialized. However we still want handles to them during Load(). - // For example, if we have only seen EVOKE, SHIFT, and STOP actions, then - // the symbol /table/fp/refer for REFER won't be serialized unless the table - // links to it. - std::vector symbols = { - action_type, action_length, action_source, action_target, - action_role, action_label - }; - Array symbols_array(&store, symbols); - top.Add("/table/symbols", symbols_array); - - StringEncoder encoder(&store); - encoder.Encode(top.Create()); - return encoder.buffer(); + frame->Add("/table/actions", actions); } } // namespace nlp diff --git a/sling/nlp/parser/action-table.h b/sling/nlp/parser/action-table.h index 8b76a358..90dece81 100644 --- a/sling/nlp/parser/action-table.h +++ b/sling/nlp/parser/action-table.h @@ -54,6 +54,9 @@ class ActionTable { // Returns the serialization of the table. string Serialize(const Store *global) const; + // Write action table in frame. + void Write(Builder *frame) const; + // Initialize the action table from store. void Init(Store *store); diff --git a/sling/nlp/parser/cascade.cc b/sling/nlp/parser/cascade.cc index 8b8d8f38..09624084 100644 --- a/sling/nlp/parser/cascade.cc +++ b/sling/nlp/parser/cascade.cc @@ -170,6 +170,5 @@ void CascadeInstance::Compute(myelin::Channel *activations, } } - } // namespace nlp } // namespace sling diff --git a/sling/nlp/parser/caspar-trainer.cc b/sling/nlp/parser/caspar-trainer.cc index 849d4db3..d50a80c4 100644 --- a/sling/nlp/parser/caspar-trainer.cc +++ b/sling/nlp/parser/caspar-trainer.cc @@ -74,6 +74,58 @@ class MultiClassDelegateLearner : public DelegateLearner { return new DelegateInstance(this); } + void Save(Flow *flow, Builder *data) override { + // Save delegate type. + data->Add("name", name_); + data->Add("runtime", "SoftmaxDelegate"); + data->Add("cell", cell_->name()); + + // Save the action table. + Store *store = data->store(); + Handle n_type = store->Lookup("/table/action/type"); + Handle n_length = store->Lookup("/table/action/length"); + Handle n_source = store->Lookup("/table/action/source"); + Handle n_target = store->Lookup("/table/action/target"); + Handle n_role = store->Lookup("/table/action/role"); + Handle n_label = store->Lookup("/table/action/label"); + Handle n_delegate = store->Lookup("/table/action/delegate"); + + Array actions(store, actions_.size()); + int index = 0; + for (const ParserAction &action : actions_.list()) { + auto type = action.type; + Builder b(store); + b.Add(n_type, static_cast(type)); + + if (type == ParserAction::REFER || type == ParserAction::EVOKE) { + if (action.length > 0) { + b.Add(n_length, static_cast(action.length)); + } + } + if (type == ParserAction::ASSIGN || + type == ParserAction::ELABORATE || + type == ParserAction::CONNECT) { + if (action.source != 0) { + b.Add(n_source, static_cast(action.source)); + } + } + if (type == ParserAction::EMBED || + type == ParserAction::REFER || + type == ParserAction::CONNECT) { + if (action.target != 0) { + b.Add(n_target, static_cast(action.target)); + } + } + if (type == ParserAction::CASCADE) { + b.Add(n_delegate, static_cast(action.delegate)); + } + if (!action.role.IsNil()) b.Add(n_role, action.role); + if (!action.label.IsNil()) b.Add(n_label, action.label); + actions.set(index++, b.Create().handle()); + } + data->Add("actions", actions); + } + // Multi-class delegate instance. class DelegateInstance : public DelegateLearnerInstance { public: @@ -229,6 +281,15 @@ class CasparTrainer : public ParserTrainer { }); } + // Save action table in model. + void SaveModel(Flow *flow, Store *store) override { + // Save action table in store. + Builder table(store); + table.AddId("/table"); + actions_.Write(&table); + table.Create(); + } + private: // Parser actions. ActionTable actions_; diff --git a/sling/nlp/parser/parser-trainer.cc b/sling/nlp/parser/parser-trainer.cc index 43a480c5..ecac96bd 100644 --- a/sling/nlp/parser/parser-trainer.cc +++ b/sling/nlp/parser/parser-trainer.cc @@ -16,6 +16,7 @@ #include +#include "sling/frame/serialization.h" #include "sling/myelin/gradient.h" #include "sling/nlp/document/document.h" #include "sling/nlp/document/lexicon.h" @@ -67,6 +68,12 @@ void ParserTrainer::Run(task::Task *task) { evaluation_corpus_ = new DocumentCorpus(&commons_, task->GetInputFiles("evaluation_corpus")); + // Output file for model. + auto *model_file = task->GetOutput("model"); + if (model_file != nullptr) { + model_filename_ = model_file->resource()->name(); + } + // Set up encoder lexicon. string normalization = task->Get("normalization", "d"); spec_.lexicon.normalization = ParseNormalization(normalization); @@ -97,15 +104,15 @@ void ParserTrainer::Run(task::Task *task) { Setup(task); // Build parser model flow graph. - BuildFlow(&flow_, true); + Build(&flow_, true); optimizer_ = GetOptimizer(task); optimizer_->Build(&flow_); // Compile model. - compiler_.Compile(&flow_, &net_); + compiler_.Compile(&flow_, &model_); // Get decoder cells and tensors. - decoder_ = net_.GetCell("ff_trunk"); + decoder_ = model_.GetCell("ff_trunk"); activations_ = decoder_->GetParameter("ff_trunk/steps"); gdecoder_ = decoder_->Gradient(); primal_ = decoder_->Primal(); @@ -115,17 +122,23 @@ void ParserTrainer::Run(task::Task *task) { drl_ = gdecoder_->GetParameter("gradients/ff_trunk/d_rl_lstm"); // Initialize model. - feature_model_.Init(net_.GetCell("ff_trunk"), + feature_model_.Init(model_.GetCell("ff_trunk"), flow_.DataBlock("spec"), &roles_, frame_limit_); - net_.InitLearnableWeights(seed_, 0.0, 0.01); - encoder_.Initialize(net_); - optimizer_->Initialize(net_); - for (auto *d : delegates_) d->Initialize(net_); + model_.InitLearnableWeights(seed_, 0.0, 0.01); + encoder_.Initialize(model_); + optimizer_->Initialize(model_); + for (auto *d : delegates_) d->Initialize(model_); commons_.Freeze(); // Train model. - Train(task, &net_); + Train(task, &model_); + + // Save final model. + if (!model_filename_.empty()) { + LOG(INFO) << "Writing parser model to " << model_filename_; + Save(model_filename_); + } // Clean up. delete optimizer_; @@ -361,9 +374,13 @@ bool ParserTrainer::Evaluate(int64 epoch, Network *model) { } void ParserTrainer::Checkpoint(int64 epoch, Network *model) { + if (!model_filename_.empty()) { + LOG(INFO) << "Checkpointing model to " << model_filename_; + Save(model_filename_); + } } -void ParserTrainer::BuildFlow(Flow *flow, bool learn) { +void ParserTrainer::Build(Flow *flow, bool learn) { // Build document input encoder. BiLSTM::Outputs lstm; if (learn) { @@ -433,6 +450,7 @@ void ParserTrainer::BuildFlow(Flow *flow, bool learn) { bins.append(std::to_string(d)); } spec->SetAttr("mark_distance_bins", bins); + spec->SetAttr("frame_limit", frame_limit_); // Concatenate mapped feature inputs. auto *fv = f.Concat(features); @@ -490,6 +508,45 @@ Document *ParserTrainer::GetNextTrainingDocument(Store *store) { return document; } +void ParserTrainer::Save(const string &filename) { + // Build model. + Flow flow; + Build(&flow, false); + + // Copy weights from trained model. + model_.SaveLearnedWeights(&flow); + + // Save lexicon. + encoder_.SaveLexicon(&flow); + + // Save extra model data in store. + Store store(&commons_); + SaveModel(&flow, &store); + + // Save delegates. + Builder cascade(&store); + cascade.AddId("/cascade"); + Array delegates(&store, delegates_.size()); + for (int i = 0; i < delegates_.size(); ++i) { + Builder data(&store); + delegates_[i]->Save(&flow, &data); + delegates.set(i, data.Create().handle()); + } + cascade.Add("delegates", delegates); + cascade.Create(); + + // Save store in flow. + StringEncoder encoder(&store); + encoder.EncodeAll(); + + Flow::Blob *blob = flow.AddBlob("commons", "frames"); + blob->data = flow.AllocateMemory(encoder.buffer()); + blob->size = encoder.buffer().size(); + + // Save model to file. + flow.Save(filename); +} + ParserTrainer::ParserEvaulationCorpus::ParserEvaulationCorpus( ParserTrainer *trainer) : trainer_(trainer) { trainer_->evaluation_corpus_->Rewind(); diff --git a/sling/nlp/parser/parser-trainer.h b/sling/nlp/parser/parser-trainer.h index 08ba7e82..fc4329f6 100644 --- a/sling/nlp/parser/parser-trainer.h +++ b/sling/nlp/parser/parser-trainer.h @@ -55,6 +55,9 @@ class DelegateLearner { // Create instance of delegate. virtual DelegateLearnerInstance *CreateInstance() = 0; + + // Save model data to flow. + virtual void Save(myelin::Flow *flow, Builder *data) = 0; }; // Interface for delegate learner instance. @@ -95,9 +98,12 @@ class ParserTrainer : public task::LearnerTask { virtual void GenerateTransitions(const Document &document, std::vector *transitions) = 0; + // Abstract method for saving extra data in final model. + virtual void SaveModel(myelin::Flow *flow, Store *store) = 0; + private: // Build flow graph for parser model. - void BuildFlow(myelin::Flow *flow, bool learn); + void Build(myelin::Flow *flow, bool learn); // Build linked feature. static myelin::Flow::Variable *LinkedFeature( @@ -113,6 +119,9 @@ class ParserTrainer : public task::LearnerTask { // Parse document using current model. void Parse(Document *document) const; + // Save trained model to file. + void Save(const string &filename); + protected: // Parallel corpus for evaluating parser on golden corpus. class ParserEvaulationCorpus : public ParallelCorpus { @@ -135,6 +144,9 @@ class ParserTrainer : public task::LearnerTask { // Evaluation corpus. DocumentCorpus *evaluation_corpus_ = nullptr; + // File name for trained model. + string model_filename_; + // Word vocabulary. std::unordered_map words_; @@ -146,7 +158,7 @@ class ParserTrainer : public task::LearnerTask { // Neural network. myelin::Flow flow_; - myelin::Network net_; + myelin::Network model_; myelin::Compiler compiler_; myelin::Optimizer *optimizer_ = nullptr; diff --git a/sling/nlp/parser/tools/train.sh b/sling/nlp/parser/tools/train.sh index 77a88975..69365701 100755 --- a/sling/nlp/parser/tools/train.sh +++ b/sling/nlp/parser/tools/train.sh @@ -84,5 +84,5 @@ then echo "Need sudo to link the sling python module.." sudo ln -s $(realpath python) $SLING_SYMLINK fi -stdbuf -o 0 python sling/nlp/parser/tools/train_pytorch.py $ARGS 2>&1 | tee ${LOG_FILE} +stdbuf -o 0 python3 sling/nlp/parser/tools/train_pytorch.py $ARGS 2>&1 | tee ${LOG_FILE} echo "Done. Log is available at ${LOG_FILE}." diff --git a/sling/nlp/parser/tools/train_caspar.py b/sling/nlp/parser/tools/train_caspar.py index 070540a7..92749c0a 100644 --- a/sling/nlp/parser/tools/train_caspar.py +++ b/sling/nlp/parser/tools/train_caspar.py @@ -2,11 +2,14 @@ import sling.flags as flags import sling.task.workflow as workflow +# Start up workflow system. flags.parse() workflow.startup() +# Create worflow. wf = workflow.Workflow("parser-training") +# Parser trainer inputs and outputs. training_corpus = wf.resource( "local/data/corpora/caspar/train_shuffled.rec", format="record/document" @@ -22,6 +25,12 @@ format="embeddings" ) +parser_model = wf.resource( + "local/data/e/caspar/caspar.flow", + format="flow" +) + +# Parser trainer task. trainer = wf.task("caspar-trainer") trainer.add_params({ @@ -38,8 +47,11 @@ trainer.attach_input("training_corpus", training_corpus) trainer.attach_input("evaluation_corpus", evaluation_corpus) trainer.attach_input("word_embeddings", word_embeddings) +trainer.attach_output("model", parser_model) +# Run parser trainer. workflow.run(wf) +# Shut down. workflow.shutdown() diff --git a/sling/task/learner.cc b/sling/task/learner.cc index 4957f4b2..74e178c2 100644 --- a/sling/task/learner.cc +++ b/sling/task/learner.cc @@ -57,6 +57,7 @@ void LearnerTask::Train(Task *task, myelin::Network *model) { std::unique_lock lock(eval_mu_); eval_model_.wait(lock); } + if (done_) break; // Run evaluation. if (!Evaluate(epoch_, model)) {