From d95c7ef6cbc15e56265a3a779e162523a63b3687 Mon Sep 17 00:00:00 2001 From: Michael Ringgaard Date: Sat, 11 Jan 2020 16:42:41 +0100 Subject: [PATCH] RNN stacks (#437) --- doc/guide/caspar.md | 142 +-- doc/guide/myelin.md | 3 +- python/myelin/builder.py | 2 +- python/myelin/flow.py | 22 +- python/myelin/simulator.py | 2 +- sling/base/registry.h | 2 +- sling/file/posix.cc | 2 +- sling/myelin/builder.cc | 102 +-- sling/myelin/builder.h | 10 +- sling/myelin/compiler.cc | 20 + sling/myelin/compute.cc | 321 +++++-- sling/myelin/compute.h | 94 +- sling/myelin/flow.cc | 50 +- sling/myelin/flow.h | 73 +- sling/myelin/generator/elementwise.cc | 72 +- sling/myelin/generator/elementwise.h | 1 + sling/myelin/gradient.cc | 3 +- sling/myelin/graph.cc | 7 +- sling/myelin/kernel/arithmetic.cc | 24 +- sling/myelin/kernel/array.cc | 123 ++- sling/myelin/kernel/avx-math.cc | 4 +- sling/myelin/kernel/cuda-array.cc | 2 +- sling/myelin/kernel/generic-math.cc | 86 +- sling/myelin/kernel/generic-matmul.cc | 2 +- sling/myelin/kernel/generic.cc | 28 +- sling/myelin/kernel/gradients.cc | 39 +- sling/myelin/learning.cc | 5 +- sling/myelin/macro-assembler.cc | 104 ++- sling/myelin/macro-assembler.h | 39 +- sling/myelin/rnn.cc | 943 ++++++++++++++++---- sling/myelin/rnn.h | 366 ++++++-- sling/myelin/simd-assembler.cc | 40 +- sling/myelin/tests/opcheck.py | 7 +- sling/nlp/document/lexical-encoder.cc | 60 +- sling/nlp/document/lexical-encoder.h | 73 +- sling/nlp/embedding/BUILD | 1 - sling/nlp/embedding/embedding-model.cc | 6 +- sling/nlp/embedding/fact-embeddings.cc | 2 +- sling/nlp/embedding/fact-plausibility.cc | 6 +- sling/nlp/parser/BUILD | 49 +- sling/nlp/parser/action-table.cc | 70 +- sling/nlp/parser/action-table.h | 29 +- sling/nlp/parser/cascade.cc | 174 ---- sling/nlp/parser/cascade.h | 145 --- sling/nlp/parser/caspar-trainer.cc | 120 +-- sling/nlp/parser/frame-evaluation.cc | 36 +- sling/nlp/parser/frame-evaluation.h | 14 +- sling/nlp/parser/multiclass-delegate.cc | 69 ++ sling/nlp/parser/parser-action.cc | 18 +- sling/nlp/parser/parser-action.h | 43 +- sling/nlp/parser/parser-features.cc | 188 +--- sling/nlp/parser/parser-features.h | 68 +- sling/nlp/parser/parser-state.cc | 112 +-- sling/nlp/parser/parser-state.h | 24 +- sling/nlp/parser/parser-trainer.cc | 235 ++--- sling/nlp/parser/parser-trainer.h | 41 +- sling/nlp/parser/parser.cc | 144 +-- sling/nlp/parser/parser.h | 55 +- sling/nlp/parser/roles.cc | 21 +- sling/nlp/parser/roles.h | 10 +- sling/nlp/parser/tools/BUILD | 1 + sling/nlp/parser/tools/parse.cc | 3 - sling/nlp/parser/tools/train_caspar.py | 31 +- sling/nlp/parser/trace.cc | 120 --- sling/nlp/parser/trace.h | 76 -- sling/nlp/parser/trainer/pytorch_modules.py | 2 +- sling/nlp/parser/transition-generator.cc | 51 +- sling/pyapi/BUILD | 1 + sling/pyapi/pymyelin.cc | 2 + sling/task/learner.cc | 2 +- 70 files changed, 2680 insertions(+), 2162 deletions(-) delete mode 100644 sling/nlp/parser/cascade.cc delete mode 100644 sling/nlp/parser/cascade.h create mode 100644 sling/nlp/parser/multiclass-delegate.cc delete mode 100644 sling/nlp/parser/trace.cc delete mode 100644 sling/nlp/parser/trace.h diff --git a/doc/guide/caspar.md b/doc/guide/caspar.md index 4a8bf114..4217cd24 100644 --- a/doc/guide/caspar.md +++ b/doc/guide/caspar.md @@ -57,11 +57,23 @@ import sling import sling.flags as flags import sling.task.workflow as workflow -# Start up workflow system. 
+flags.define("--accurate", default=False,action='store_true') + flags.parse() + +if flags.arg.accurate: + modelfn = "local/data/e/caspar/caspar-accurate.flow" + rnn_layers = 3 + rnn_dim = 192 +else: + modelfn = "local/data/e/caspar/caspar.flow" + rnn_layers = 1 + rnn_dim = 128 + +# Start up workflow system. workflow.startup() -# Create worflow. +# Create workflow. wf = workflow.Workflow("parser-training") # Parser trainer inputs and outputs. @@ -80,23 +92,28 @@ word_embeddings = wf.resource( format="embeddings" ) -parser_model = wf.resource( - "local/data/e/caspar/caspar.flow", - format="flow" -) +parser_model = wf.resource(modelfn, format="flow") # Parser trainer task. trainer = wf.task("caspar-trainer") trainer.add_params({ + "rnn_type": 1, + "rnn_dim": rnn_dim, + "rnn_highways": True, + "rnn_layers": rnn_layers, + "dropout": 0.2, + "ff_l2reg": 0.0001, + "learning_rate": 1.0, "learning_rate_decay": 0.8, "clipping": 1, "optimizer": "sgd", - "epochs": 50000, "batch_size": 32, "rampup": 120, - "report_interval": 500 + "report_interval": 1000, + "learning_rate_cliff": 40000, + "epochs": 50000, }) trainer.attach_input("training_corpus", training_corpus) @@ -111,9 +128,10 @@ workflow.run(wf) workflow.shutdown() ``` -This model takes ~90 minutes to train. It will output evaluation metrics -each 500 epochs, and when it is done, the final parser model will be written -to `local/data/e/caspar/caspar.flow`. +This model takes ~30 minutes to train. It will output evaluation metrics +each 1000 epochs, and when it is done, the final parser model will be written +to `local/data/e/caspar/caspar.flow`. You can train a slightly more accurate, +but much slower parser by using the `--accurate` flag. If you don't have access to OntoNotes 5, you can download a pre-trained model from [here](http://www.jbox.dk/sling/caspar.flow). @@ -203,34 +221,48 @@ This tool takes the following commandline arguments: [... I sling/nlp/parser/tools/parse.cc:131] Load parser from local/data/e/caspar/caspar.flow [... I sling/nlp/parser/tools/parse.cc:140] 34.7368 ms loading parser [... 
I sling/nlp/parser/tools/parse.cc:235] Evaluating parser on local/data/corpora/caspar/dev.rec - SPAN_P+=76898 - SPAN_P-=6800 - SPAN_R+=76898 - SPAN_R-=6192 - SPAN_Precision=91.8755 - SPAN_Recall=92.5478 - SPAN_F1=92.2105 - FRAME_P+=77866 - FRAME_P-=5859 - FRAME_R+=77859 - FRAME_R-=5233 - FRAME_Precision=93.0021 - FRAME_Recall=93.7022 - FRAME_F1=93.3508 - TYPE_P+=74277 - TYPE_P-=9448 - TYPE_R+=74275 - TYPE_R-=8817 - TYPE_Precision=88.7154 - TYPE_Recall=89.3889 - TYPE_F1=89.0509 - ROLE_P+=37762 - ROLE_P-=16848 - ROLE_R+=37755 - ROLE_R-=16397 - ROLE_Precision=69.1485 - ROLE_Recall=69.7204 - ROLE_F1=69.4333 + SPAN_P+=77757 + SPAN_P-=6185 + SPAN_R+=77757 + SPAN_R-=5333 + SPAN_Precision=92.6318 + SPAN_Recall=93.5817 + SPAN_F1=93.1043 + FRAME_P+=78724 + FRAME_P-=5225 + FRAME_R+=78715 + FRAME_R-=4377 + FRAME_Precision=93.776 + FRAME_Recall=94.7323 + FRAME_F1=94.2517 + PAIR_P+=52597 + PAIR_P-=2339 + PAIR_R+=51988 + PAIR_R-=2164 + PAIR_Precision=95.7423 + PAIR_Recall=96.0038 + PAIR_F1=95.8729 + EDGE_P+=44432 + EDGE_P-=10504 + EDGE_R+=44400 + EDGE_R-=9752 + EDGE_Precision=80.8796 + EDGE_Recall=81.9914 + EDGE_F1=81.4317 + ROLE_P+=39836 + ROLE_P-=15100 + ROLE_R+=39826 + ROLE_R-=14326 + ROLE_Precision=72.5135 + ROLE_Recall=73.5448 + ROLE_F1=73.0255 + TYPE_P+=75604 + TYPE_P-=8345 + TYPE_R+=75595 + TYPE_R-=7497 + TYPE_Precision=90.0594 + TYPE_Recall=90.9775 + TYPE_F1=90.5161 LABEL_P+=0 LABEL_P-=0 LABEL_R+=0 @@ -238,24 +270,24 @@ This tool takes the following commandline arguments: LABEL_Precision=0 LABEL_Recall=0 LABEL_F1=0 - SLOT_P+=112039 - SLOT_P-=26296 - SLOT_R+=112030 - SLOT_R-=25214 - SLOT_Precision=80.9911 - SLOT_Recall=81.6283 - SLOT_F1=81.3085 - COMBINED_P+=266803 - COMBINED_P-=38955 - COMBINED_R+=266787 - COMBINED_R-=36639 - COMBINED_Precision=87.2595 - COMBINED_Recall=87.9249 - COMBINED_F1=87.591 + SLOT_P+=115440 + SLOT_P-=23445 + SLOT_R+=115421 + SLOT_R-=21823 + SLOT_Precision=83.1191 + SLOT_Recall=84.0991 + SLOT_F1=83.6063 + COMBINED_P+=271921 + COMBINED_P-=34855 + COMBINED_R+=271893 + COMBINED_R-=31533 + COMBINED_Precision=88.6383 + COMBINED_Recall=89.6077 + COMBINED_F1=89.1203 #GOLDEN_SPANS=83090 - #PREDICTED_SPANS=83698 + #PREDICTED_SPANS=83942 #GOLDEN_FRAMES=83092 - #PREDICTED_FRAMES=83725 + #PREDICTED_FRAMES=83949 ``` ## Using the CASPAR parser in Python diff --git a/doc/guide/myelin.md b/doc/guide/myelin.md index 1eb97602..2cab1d12 100644 --- a/doc/guide/myelin.md +++ b/doc/guide/myelin.md @@ -380,6 +380,7 @@ var = <#flags> (IN=1, OUT=2, REF=4, LEARNABLE=8 UNIQUE=16, from version 5) <#aliases> * + <#attrs> attr* (from version 6) <#bytes> value op = <#flags> (unused, from version 5) @@ -411,7 +412,7 @@ dtype = "float16" | "float32" | "float64" | "int8" | "uint8" | "int16" | "uint16" | "int32" | "uint64" "flow" = 0x776f6c66 -version = 3 | 4 | 5 +version = 3 | 4 | 5 | 6 ``` A flow file begins with the _magic_ string "flow" followed by a version number. 
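The version 6 grammar above adds a per-variable attribute list to the flow file format. As a minimal sketch (not part of the patch; the variable name, attribute, and output path are hypothetical), an attribute attached through the C++ `Flow` API ends up in this new field when the flow is saved with the default version:

```cpp
#include "sling/myelin/flow.h"

using sling::myelin::Flow;

// Attach an attribute to a variable and save a version-6 flow file.
// Variable name, attribute, and path are placeholders for illustration.
void AnnotateAndSave(Flow *flow) {
  Flow::Variable *x = flow->Var("features/words");
  if (x != nullptr) x->SetAttr("layer", "rnn");
  flow->Save("/tmp/model.flow");  // default version is now 6
}
```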
diff --git a/python/myelin/builder.py b/python/myelin/builder.py index 09cce185..d80971fd 100644 --- a/python/myelin/builder.py +++ b/python/myelin/builder.py @@ -226,7 +226,7 @@ def varname(self, var): index += 1 def concat(self, args, name=None): - op = self.rawop("ConcatV2", name) + op = self.rawop("Concat", name) shape = [args[0].shape[0], 0] for arg in args: op.add_input(arg) diff --git a/python/myelin/flow.py b/python/myelin/flow.py index 6c30675e..8d598056 100644 --- a/python/myelin/flow.py +++ b/python/myelin/flow.py @@ -155,6 +155,7 @@ def __init__(self, name): self.aliases = [] self.type = None self.shape = [] + self.attrs = {} self.data = None self.producer = None self.consumers = [] @@ -214,6 +215,13 @@ def unique(self, value): else: self.flags &= ~16 + def add_attr(self, name, value): + if type(value) is bool: value = int(value) + self.attrs[name] = str(value) + + def attr(self, name): + return self.attrs.get(name, None) + def shape_defined(self): for d in self.shape: if d == -1: return False @@ -560,7 +568,7 @@ def save(self, filename): # Write flow file header f = FileWriter(filename) f.write('flow') - f.write_int(5) + f.write_int(6) f.write_int(self.flags) # Write variables. @@ -574,6 +582,10 @@ def save(self, filename): f.write_string(var.type) f.write_int(len(var.shape)) for d in var.shape: f.write_int(d) + f.write_int(len(var.attrs)) + for a in var.attrs: + f.write_string(a) + f.write_string(op.attrs[a]) f.write_object(var.data) # Write operations. @@ -636,7 +648,7 @@ def load(self, filename): assert magic == memoryview(b'flow'), magic.tobytes() version = f.read_int() - assert version == 4 or version == 5, version + assert version == 4 or version == 5 or version == 6, version if version >= 5: self.flags = f.read_int() num_vars = f.read_int() @@ -658,6 +670,12 @@ def load(self, filename): shape.append(f.read_int()) var = self.var(name, type=t, shape=shape) var.flags = flags + if version >= 6: + num_attr = f.read_int() + for _ in range(num_attr): + attr_name = f.read_string() + attr_val = f.read_string() + var.add_attr(attr_name, attr_val) size = f.read_long() if size > 0: var.data = f.slice(size) # avoid creating a copy diff --git a/python/myelin/simulator.py b/python/myelin/simulator.py index 39d6694a..25952f66 100644 --- a/python/myelin/simulator.py +++ b/python/myelin/simulator.py @@ -218,7 +218,7 @@ def compute(flow, f, data): v[o[0]] = np.array(len(v[i[0]].shape)) elif op.type == "Identity": v[o[0]] = v[i[0]] - elif op.type == "ConcatV2": + elif op.type == "Concat": n = int(op.attr("N")) axis = v[i[n]] seq = [] diff --git a/sling/base/registry.h b/sling/base/registry.h index 3e34b412..6ee0300b 100644 --- a/sling/base/registry.h +++ b/sling/base/registry.h @@ -192,7 +192,7 @@ template struct ComponentRegistry { Registrar *r = components; while (r != nullptr && strcmp(type, r->type()) != 0) r = r->next(); if (r == nullptr) { - LOG(FATAL) << "Unknown " << name << " component: '" << type << "'."; + LOG(FATAL) << "Unknown " << name << " component: " << type; } return r; } diff --git a/sling/file/posix.cc b/sling/file/posix.cc index 911433d3..fe04a0b8 100644 --- a/sling/file/posix.cc +++ b/sling/file/posix.cc @@ -271,7 +271,7 @@ class PosixFileSystem : public FileSystem { return Status::OK; } - Status FlushMappedMemory(void *data, size_t size) { + Status FlushMappedMemory(void *data, size_t size) override { if (msync(data, size, MS_SYNC) != 0) return IOError("msync", errno); return Status::OK; } diff --git a/sling/myelin/builder.cc b/sling/myelin/builder.cc index 
0051debe..c16f752b 100644 --- a/sling/myelin/builder.cc +++ b/sling/myelin/builder.cc @@ -60,8 +60,18 @@ Flow::Variable *FlowBuilder::Parameter(const string &name, return var; } -Flow::Variable *FlowBuilder::Random(Variable *var) { - var->set_random(true); +Flow::Variable *FlowBuilder::RandomUniform(Variable *var) { + var->init = Flow::Variable::INIT_UNIFORM; + return var; +} + +Flow::Variable *FlowBuilder::RandomNormal(Variable *var) { + var->init = Flow::Variable::INIT_NORMAL; + return var; +} + +Flow::Variable *FlowBuilder::RandomOrtho(Variable *var) { + var->init = Flow::Variable::INIT_ORTHO; return var; } @@ -224,11 +234,29 @@ Flow::Variable *FlowBuilder::Concat(const std::vector &parts, shape.set(axis, width); std::vector args = parts; args.push_back(Const(axis)); - auto *concat = Op("ConcatV2", args, parts[0]->type, shape); + auto *concat = Op("Concat", args, parts[0]->type, shape); concat->producer->SetAttr("N", n); return concat; } +std::vector FlowBuilder::Split(Variable *v, int splits, + int axis) { + CHECK(v->dim(axis) % splits == 0) + << "Cannot split " << v->shape.ToString() << " into " << splits + << " parts along dimension " << axis; + std::vector parts; + Operation *op = RawOp("Split", {v, Const(splits), Const(axis)}); + Shape shape = v->shape; + shape.set(axis, shape.dim(axis) / splits); + for (int i = 0; i < splits; ++i) { + string name = op->name + ":" + std::to_string(i); + Variable *out = flow_->AddVariable(name, v->type, shape); + op->AddOutput(out); + parts.push_back(out); + } + return parts; +} + Flow::Variable *FlowBuilder::FFLayers(Variable *input, std::vector layers, int hidden, @@ -242,7 +270,8 @@ Flow::Variable *FlowBuilder::FFLayers(Variable *input, int width = layers[l]; // Add weight matrix. - auto *W = Random(Parameter("W" + std::to_string(l), type, {height, width})); + auto *W = Parameter("W" + std::to_string(l), type, {height, width}); + RandomNormal(W); v = MatMul(v, W); // Optionally add bias. @@ -267,71 +296,6 @@ Flow::Variable *FlowBuilder::FFLayers(Variable *input, return logits; } -Flow::Variable *FlowBuilder::LSTMLayer(Variable *input, int size) { - // Get LSTM dimensions. - Type type = input->type; - int input_dim = input->dim(1); - - // Define parameters. 
- auto *x2i = Random(Parameter("x2i", type, {input_dim, size})); - auto *h2i = Random(Parameter("h2i", type, {size, size})); - auto *c2i = Random(Parameter("c2i", type, {size, size})); - auto *bi = Parameter("bi", type, {1, size}); - - auto *x2o = Random(Parameter("x2o", type, {input_dim, size})); - auto *h2o = Random(Parameter("h2o", type, {size, size})); - auto *c2o = Random(Parameter("c2o", type, {size, size})); - auto *bo = Parameter("bo", type, {1, size}); - - auto *x2c = Random(Parameter("x2c", type, {input_dim, size})); - auto *h2c = Random(Parameter("h2c", type, {size, size})); - auto *bc = Parameter("bc", type, {1, size}); - - // Channels -- h_in, c_in = h_{t-1}, c_{t-1} - auto *h_in = Placeholder("h_in", type, {1, size}, true); - auto *c_in = Placeholder("c_in", type, {1, size}, true); - - // Input -- i_t = sigmoid(x_t * x2i + h_in * h2i + c_in * c2i + bi) - auto *i_ait = Name(Add(MatMul(input, x2i), - Add(MatMul(h_in, h2i), - Add(MatMul(c_in, c2i), bi))), - "i_ait"); - auto *i_it = Name(Sigmoid(i_ait), "i_it"); - - // Forget -- f_t = 1 - i_t - auto *i_ft = Name(Sub(One(), i_it), "i_ft"); - - // Memory -- tanh(x_t * x2c + h_in * h2c + h_in * h2c + bc) - auto *i_awt = Name(Add(MatMul(input, x2c), - Add(MatMul(h_in, h2c), bc)), - "i_awt"); - auto *i_wt = Name(Tanh(i_awt), "i_wt"); - - // Control -- c_out = c_t = i_t * w_t + f_t * c_in - auto *c_out = Name(Add(Mul(i_it, i_wt), Mul(i_ft, c_in)), "c_out"); - c_out->set_out()->set_ref(); - - // Output -- o_t = sigmoid(x_t * x2o + c_t * c2o + h_in * h2o + bo) - auto *i_aot = Name(Add(MatMul(input, x2o), - Add(MatMul(c_out, c2o), - Add(MatMul(h_in, h2o), bo))), - "i_aot"); - auto *i_ot = Name(Sigmoid(i_aot), "i_ot"); - - // Hidden -- h_out = h_t = o_t * tanh(c_out) - auto *h_out = Name(Mul(i_ot, Tanh(c_out)), "h_out"); - h_out->set_out()->set_ref(); - - // Connectors for hidden and control channels. - flow_->Connect({h_in, h_out}); - flow_->Connect({c_in, c_out}); - - // The control channel has a single-source gradient. - c_in->set_unique(); - - return h_out; -} - } // namespace myelin } // namespace sling diff --git a/sling/myelin/builder.h b/sling/myelin/builder.h index 948607ea..2f18a275 100644 --- a/sling/myelin/builder.h +++ b/sling/myelin/builder.h @@ -80,7 +80,9 @@ class FlowBuilder : public Scope { Variable *Parameter(const string &name, Type type, const Shape &shape); // Initialize variable with random values. Returns the variable itself. - Variable *Random(Variable *var); + Variable *RandomUniform(Variable *var); + Variable *RandomNormal(Variable *var); + Variable *RandomOrtho(Variable *var); // Add input variable to function. Variable *Placeholder(const string &name, Type type, const Shape &shape, @@ -382,6 +384,9 @@ class FlowBuilder : public Scope { // Concatenation. Variable *Concat(const std::vector &parts, int axis = 1); + // Splitting. + std::vector Split(Variable *v, int splits, int axis = 1); + // Slicing. Variable *Slice(Variable *v, Variable *begin, const Shape &size) { return Op("Slice", {v, begin, Const(size)}, v->type, size); @@ -404,9 +409,6 @@ class FlowBuilder : public Scope { return FFLayers(input, {size}, -1, bias); } - // Long short-term memory (LSTM) layer. - Variable *LSTMLayer(Variable *input, int size); - // Return function for builder. 
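With `LSTMLayer` removed from the builder (presumably superseded by the reworked RNN stack code in `sling/myelin/rnn.cc`), the new builder entry points are the per-variable initialization modes and `Split`. A minimal usage sketch, assuming hypothetical shapes and names not taken from the patch:

```cpp
#include "sling/myelin/builder.h"
#include "sling/myelin/flow.h"

using namespace sling::myelin;

// Build a small function that initializes a weight matrix orthogonally and
// splits the resulting activation into four equal parts along axis 1.
void BuildExample(Flow *flow) {
  FlowBuilder tf(flow, "example");
  auto *x = tf.Placeholder("x", DT_FLOAT, {1, 256});
  auto *W = tf.Parameter("W", DT_FLOAT, {256, 512});
  tf.RandomOrtho(W);                       // replaces the old set_random() path
  auto *h = tf.MatMul(x, W);
  auto gates = tf.Split(h, 4, 1);          // four {1, 128} outputs
  tf.Name(tf.Sigmoid(gates[0]), "gate0");
}
```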
Function *func() const { return func_; } diff --git a/sling/myelin/compiler.cc b/sling/myelin/compiler.cc index e6f2df7d..bc7d1233 100644 --- a/sling/myelin/compiler.cc +++ b/sling/myelin/compiler.cc @@ -47,6 +47,7 @@ DEFINE_bool(dump_input_flow, false, "Dump raw input flow to log"); DEFINE_bool(dump_final_flow, false, "Dump final analyzed flow to log"); DEFINE_bool(dump_cells, false, "Dump cells after compilation"); DEFINE_bool(dump_code, false, "Dump generated assembly code"); +DEFINE_bool(param_stats, false, "Dump model parameter statistics"); DEFINE_bool(check_flow_consistency, false, "Check that flow is consistent"); DEFINE_bool(dynamic_instance_allocation, false, "Dynamic instance allocation"); DEFINE_bool(mkl, false, "Use Intel Math Kernel Library"); @@ -60,6 +61,7 @@ DEFINE_bool(jit_debug, false, "Debug break in jit code"); DEFINE_int32(cuda_device, -1, "CUDA device number"); DEFINE_int32(cuda_context_flags, 0, "CUDA context flags"); DEFINE_int32(sparse_threshold, 64, "Minimum dimension size for sparse update"); +DEFINE_bool(compile_only, false, "Stop after compilation"); namespace sling { namespace myelin { @@ -186,6 +188,18 @@ void Compiler::Compile(Flow *flow, Network *net) { } } + // Optionally output parameter statictics. + if (FLAGS_param_stats) { + int total = 0; + for (Tensor *t : net->globals()) { + if (t->IsScalar()) continue; + if (t->type() != DT_FLOAT) continue; + printf("%8d %s\n", t->elements(), t->name().c_str()); + total += t->elements(); + } + printf("%8d TOTAL\n", total); + } + // Optionally output generated code to ELF file. if (!FLAGS_jit_code.empty() || FLAGS_dump_code) { // Link code. @@ -219,6 +233,12 @@ void Compiler::Compile(Flow *flow, Network *net) { unlink(tmpname); } } + + // Stop after compilation if requested. + if (FLAGS_compile_only) { + LOG(INFO) << "Stop after compilation"; + exit(1); + } } void Compiler::WriteGraph(const Flow &flow, diff --git a/sling/myelin/compute.cc b/sling/myelin/compute.cc index 20d61099..0790aae9 100644 --- a/sling/myelin/compute.cc +++ b/sling/myelin/compute.cc @@ -434,10 +434,8 @@ class InstanceAllocator { }; Library::~Library() { - if (owns_kernels_) { - for (auto o : kernels_) { - for (auto k : o.second) delete k; - } + for (auto o : kernels_) { + for (auto k : o.second) delete k; } } @@ -492,25 +490,6 @@ const Library::Kernels &Library::Lookup(const string &op) const { return f->second; } -bool Library::Singleton(const string &op, - const string &name, - Library *singleton) const { - // Singleton library must be empty or already non-owning. - CHECK(!singleton->owns_kernels_ || singleton->kernels_.empty()); - singleton->owns_kernels_ = false; - - // Find kernel. 
- auto f = kernels_.find(op); - if (f == kernels_.end()) return false; - for (Kernel *kernel : f->second) { - if (kernel->Name() == name) { - singleton->kernels_[kernel->Operation()].push_back(kernel); - return true; - } - } - return false; -} - void Tensor::Link(Tensor *link) { next_link_->prev_link_ = link->prev_link_; link->prev_link_->next_link_ = next_link_; @@ -574,6 +553,12 @@ int Tensor::ChannelElementSize() const { return Align(size(), byte_alignment()); } +int Tensor::AxisSize(int axis) const { + if (axis > 0) return stride(axis - 1); + if (dynamic_) return ChannelElementSize(); + return size_; +} + bool Tensor::SupportsOrder(Order order) { return CombinedOrder(order_, order) != CONFLICTING_ORDER; } @@ -649,6 +634,7 @@ string Tensor::TypeString() const { string str; if (ref_) str.append("&"); str.append(TypeTraits::of(type_).name()); + if (dynamic_) str.append("<>"); if (!shape_.scalar()) { str.append("["); str.append(shape_.ToString()); @@ -659,7 +645,15 @@ string Tensor::TypeString() const { string Tensor::ToString(const char *data, bool deref) const { // Resolve references. - if (deref && ref()) data = *reinterpret_cast(data); + if (deref) { + if (dynamic()) { + data = *reinterpret_cast(data); + if (data == nullptr) return "null"; + data = *reinterpret_cast(data); + } else if (ref()) { + data = *reinterpret_cast(data); + } + } // Check for shape and null. if (!shape().defined()) return "*"; @@ -702,26 +696,30 @@ string Tensor::ToString(const char *data, bool deref) const { } Channel::Channel(const Tensor *format) : format_(format) { - // Align the element size to the byte alignment of the format tensor to ensure - // proper alignment of the elements in the channel array. - DCHECK(format->order() == ROW_MAJOR) << format->name(); - DCHECK_GE(format->rank(), 1) << format->name(); - DCHECK_EQ(format->dim(0), 1) << format->name(); - element_size_ = format->ChannelElementSize(); + if (format != nullptr) { + // Align the element size to the byte alignment of the format tensor to + // ensure proper alignment of the elements in the channel array. + DCHECK(format->order() == ROW_MAJOR) << format->name(); + DCHECK_GE(format->rank(), 1) << format->name(); + DCHECK_EQ(format->dim(0), 1) << format->name(); + element_size_ = format->ChannelElementSize(); - // Channel are aligned to the element alignment and cache lines. - EnsureAlignment(&alignment_, format->byte_alignment()); - EnsureAlignment(&alignment_, jit::CPU::CacheLineSize()); + // Channel are aligned to the element alignment and cache lines. + EnsureAlignment(&alignment_, format->byte_alignment()); + EnsureAlignment(&alignment_, jit::CPU::CacheLineSize()); + } } Channel::~Channel() { - runtime()->FreeChannel(data_, placement()); + if (data_ != nullptr) { + runtime()->FreeChannel(data_, placement()); + } } -void Channel::resize(int n) { +void Channel::resize(size_t n) { // Allocate more space if needed. if (n > capacity_) { - int cap = capacity_ * 2; + size_t cap = capacity_ * 2; if (cap < n) cap = n; if (cap < 8) cap = 8; reserve(cap); @@ -738,10 +736,10 @@ void Channel::resize(int n) { size_ = n; } -void Channel::reset(int n) { +void Channel::reset(size_t n) { // Allocate more space if needed. if (n > capacity_) { - int cap = capacity_ * 2; + size_t cap = capacity_ * 2; if (cap < n) cap = n; if (cap < 8) cap = 8; reserve(cap); @@ -754,7 +752,7 @@ void Channel::reset(int n) { size_ = n; } -void Channel::reserve(int n) { +void Channel::reserve(size_t n) { // Never remove any existing elements. 
if (n < size_) return; if (n == capacity_) return; @@ -770,7 +768,7 @@ void Channel::reserve(int n) { capacity_ = n; } -void Channel::zero(int n) { +void Channel::zero(size_t n) { runtime()->ClearChannel(data_, n * element_size_, element_size_, placement()); } @@ -805,11 +803,17 @@ ProfileSummary::~ProfileSummary() { } Instance::Instance(const Cell *cell) : cell_(cell) { - cell_->runtime()->AllocateInstance(this); + if (cell_ != nullptr) { + cell_->runtime()->AllocateInstance(this); + } else { + data_ = nullptr; + } } Instance::~Instance() { - cell_->runtime()->FreeInstance(this); + if (cell_ != nullptr) { + cell_->runtime()->FreeInstance(this); + } } void Instance::Clear() { @@ -859,6 +863,39 @@ string Instance::ToString() const { return str; } +InstanceArray::InstanceArray(Cell *cell) + : cell_(cell), begin_(nullptr), end_(nullptr), limit_(nullptr) {} + +InstanceArray::~InstanceArray() { + // Destruct all elements. + for (Instance *d = begin_; d < limit_; ++d) d->~Instance(); + + // Free array. + free(begin_); +} + +void InstanceArray::Resize(size_t size) { + int cap = capacity(); + if (size < cap) { + end_ = begin_ + size; + } else if (size > cap) { + // This awkward way of assigning the new data buffer to begin_ is needed to + // avoid getting a GCC 8+ class-memaccess warning. + size_t bytes = size * sizeof(Instance); + void **data = reinterpret_cast(&begin_); + *data = realloc(*data, bytes); + end_ = begin_ + cap; + limit_ = begin_ + size; + while (end_ < limit_) new (end_++) Instance(cell_); + } +} + +void InstanceArray::Clear() { + for (Instance *d = begin_; d < limit_; ++d) d->~Instance(); + free(begin_); + begin_ = end_ = limit_ = nullptr; +} + void Step::SetRegisterUsage(int regs) { if (cell_ != nullptr && cell_->register_usage_ < regs) { cell_->register_usage_ = regs; @@ -866,8 +903,8 @@ void Step::SetRegisterUsage(int regs) { } void Step::SetPreservedRegisterUsage(int regs) { - // There are eight caller-saved registers. - SetRegisterUsage(8 + regs); + // There are nine caller-saved registers. + SetRegisterUsage(9 + regs); } bool Step::AllowInPlace(int input, int output, bool preserved) { @@ -888,12 +925,20 @@ bool Step::AllowInPlace(int input, int output, bool preserved) { if (t->out()) return false; } if (t->ref() != out->ref()) return false; + if (t->dynamic() != out->dynamic()) return false; in = t; t = t->shared(); } // Check if output can be shared. if (out->shared()) return false; + if (out->ref()) { + if (preserved) { + if (out->out() && in->in()) return false; + } else { + if (out->out() || in->in()) return false; + } + } // Share input and output. out->set_shared(in); @@ -988,33 +1033,157 @@ Network::~Network() { for (auto *s : steps_) delete s; } -void Network::InitLearnableWeights(int64 seed, float mean, float stddev) { +// Orthogonalize a set of vectors stored as the columns of matrix A (m x n) +// using the Gram-Schmidt process. +static void OrthogonalizeColumns(float *A, int m, int n) { + // Orthogonalize one column vector at a time. + float *aj, *ak; + for (int j = 0; j < n; ++j) { + // To orthogonalize the vector in column j with respect to the previous + // vectors, subtract from it its projection onto each of the previous + // vectors. + for (int k = 0; k < j; ++k) { + // Compute dot product r = A_k * A_j. + float r = 0.0; + ak = A + k; + aj = A + j; + for (int i = 0; i < m; ++i, ak += n, aj += n) r += *ak * *aj; + + // Update A_j -= r * A_k. + ak = A + k; + aj = A + j; + for (int i = 0; i < m; ++i, ak += n, aj += n) *aj -= r * *ak; + } + + // Normalize A_j. 
+ aj = A + j; + float sum = 0.0; + for (int i = 0; i < m; ++i, aj += n) sum += *aj * *aj; + float scaler = 1.0/ sqrt(sum); + aj = A + j; + for (int i = 0; i < m; ++i, aj += n) *aj *= scaler; + } +} + +// Orthogonalize a set of vectors stored as the rows of matrix A (m x n) +// using the Gram-Schmidt process. +static void OrthogonalizeRows(float *A, int m, int n) { + // Orthogonalize one row vector at a time. + float *aj, *ak; + for (int j = 0; j < m; ++j) { + // To orthogonalize the vector in row j with respect to the previous + // vectors, subtract from it its projection onto each of the previous + // vectors. + aj = A + j * n; + for (int k = 0; k < j; ++k) { + // Compute dot product r = A_k * A_j. + float r = 0.0; + ak = A + k * n; + for (int i = 0; i < n; ++i) r += ak[i] * aj[i]; + + // Update A_j -= r * A_k. + for (int i = 0; i < n; ++i) aj[i] -= r * ak[i]; + } + + // Normalize A_j. + float sum = 0.0; + for (int i = 0; i < n; ++i) sum += aj[i] * aj[i]; + float scaler = 1.0/ sqrt(sum); + for (int i = 0; i < n; ++i) aj[i] *= scaler; + } +} + +void Network::InitModelParameters(int64 seed) { // Initialize random generator. std::mt19937_64 prng; prng.seed(seed); - std::uniform_real_distribution dist(mean, stddev); + std::normal_distribution normal(0, 1.0); + std::uniform_real_distribution uniform(-1.0, 1.0); - // Initialize learnable variable with Gaussian noise. + // Initialize model parameters. for (auto *tensor : globals_) { - if (!tensor->random_init_) continue; if (tensor->type() != DT_FLOAT) continue; if (tensor->data() == nullptr) continue; + float dim = tensor->elements(); + float scale = 1.0 / sqrt(dim); + if (tensor->HasStandardLayout()) { float *data = reinterpret_cast(tensor->data()); - for (int i = 0; i < tensor->elements(); ++i) { - data[i] = dist(prng); + switch (tensor->init_) { + case Flow::Variable::INIT_ZERO: + // Variables are already zero-initialized. + break; + case Flow::Variable::INIT_UNIFORM: { + for (int i = 0; i < tensor->elements(); ++i) { + data[i] = uniform(prng) * scale; + } + break; + } + case Flow::Variable::INIT_NORMAL: { + for (int i = 0; i < tensor->elements(); ++i) { + data[i] = normal(prng) * scale; + } + break; + } + case Flow::Variable::INIT_ORTHO: { + for (int i = 0; i < tensor->elements(); ++i) { + data[i] = normal(prng); + } + + if (tensor->rank() >= 2) { + int m = tensor->dim(0); + int n = tensor->elements() / m; + if (n > m) { + OrthogonalizeRows(data, m, n); + } else { + OrthogonalizeColumns(data, m, n); + } + } + break; + } + default: + LOG(WARNING) << "Unknown initialization for " << tensor->name(); } } else { - for (int i = 0; i < tensor->elements(); ++i) { - size_t offset = tensor->LinearOffset(i); - *reinterpret_cast(tensor->data() + offset) = dist(prng); + switch (tensor->init_) { + case Flow::Variable::INIT_ZERO: + // Variables are already zero-initialized. 
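Restating the two orthogonalization helpers compactly (editorial summary, not from the patch): for column $j$ of $A$ (and analogously for rows), with the previously processed vectors already normalized, each pass computes

$$
a_j \leftarrow a_j - \sum_{k<j} \big(a_k^{\top} a_j\big)\, a_k,
\qquad
a_j \leftarrow \frac{a_j}{\lVert a_j \rVert_2},
$$

i.e. classical Gram-Schmidt, which is what the `INIT_ORTHO` initialization below relies on.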
+ break; + case Flow::Variable::INIT_UNIFORM: { + for (int i = 0; i < tensor->elements(); ++i) { + size_t offset = tensor->LinearOffset(i); + float *p = reinterpret_cast(tensor->data() + offset); + *p = uniform(prng) * scale; + } + break; + } + case Flow::Variable::INIT_NORMAL: { + for (int i = 0; i < tensor->elements(); ++i) { + size_t offset = tensor->LinearOffset(i); + float *p = reinterpret_cast(tensor->data() + offset); + *p = normal(prng) * scale; + } + break; + } + case Flow::Variable::INIT_ORTHO: { + LOG(WARNING) << "Cannot initialize tensor with non-standard layout " + << "with orthogonal vectors: " << tensor->name(); + for (int i = 0; i < tensor->elements(); ++i) { + size_t offset = tensor->LinearOffset(i); + float *p = reinterpret_cast(tensor->data() + offset); + *p = normal(prng) * scale; + } + break; + } + default: + LOG(WARNING) << "Unknown initialization for " << tensor->name(); } } } } -void Network::SaveLearnedWeights(Flow *flow) const { +void Network::SaveParameters(Flow *flow) const { // Find all learnable variables in flow. for (Flow::Variable *var : flow->vars()) { if (!var->learnable()) continue; @@ -1046,7 +1215,44 @@ void Network::SaveLearnedWeights(Flow *flow) const { dst += element_size; } } - var->clear_learnable(); + var->set_learnable(false); + } +} + +void Network::LoadParameters(const Flow &flow) { + // Find all learnable variables in flow. + for (const Flow::Variable *var : flow.vars()) { + // Find tensor for variable. + Tensor *tensor = LookupParameter(var->name); + if (tensor == nullptr) continue; + if (!tensor->learnable()) continue; + + // Check that type and shape match. + if (tensor->type() != var->type || tensor->shape() != var->shape) { + LOG(WARNING) << "Tensor " << tensor->name() << " type mismatch: " + << tensor->TypeString() << " vs " << var->TypeString(); + continue; + } + + // If tensor data has standard layout we can copy the data directly. + // Otherwise, tensor data is copied element-by-element. + if (tensor->HasStandardLayout()) { + // Copy directly. + memcpy(tensor->data(), var->data, var->size); + } else { + // Allocate data. + int elements = tensor->shape().elements(); + int element_size = tensor->element_size(); + char *dst = tensor->data(); + char *src = var->data; + + // Copy elements one at a time. + for (int i = 0; i < elements; ++i) { + size_t offset = tensor->LinearOffset(i); + memcpy(dst + offset, src, element_size); + src += element_size; + } + } } } @@ -1102,13 +1308,14 @@ bool Network::Compile(const Flow &flow, const Library &library) { varmap[var] = tensor; tensor->constant_ = var->constant(); tensor->local_ = !var->global(); - tensor->random_init_ = var->global() && var->random() && var->learnable(); + tensor->init_ = var->init; tensor->name_ = var->name; for (const string &alias : var->aliases) { names_[alias] = tensor; } tensor->type_ = var->type; tensor->ref_ = var->ref(); + tensor->dynamic_ = var->dynamic(); tensor->shape_ = var->shape; tensor->aligned_ = var->shape; tensor->minalign_.fill(var->rank(), 1); @@ -1462,7 +1669,7 @@ bool Network::Compile(const Flow &flow, const Library &library) { // Set tensor size. tensor->size_ = size; - tensor->space_ = tensor->ref() ? sizeof(void *) : size; + tensor->space_ = tensor->ref() || tensor->dynamic() ? sizeof(void *) : size; // Determine placement for tensor based on producer and consumer locations. 
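The renamed parameter methods are meant to be used around training roughly as follows; a minimal sketch assuming an existing flow, with the seed and output path as placeholders:

```cpp
#include "sling/myelin/compiler.h"
#include "sling/myelin/compute.h"
#include "sling/myelin/flow.h"

using namespace sling::myelin;

// Initialize, train, and persist model parameters with the renamed methods.
void TrainAndSave(Flow *flow) {
  Compiler compiler;
  Network net;
  compiler.Compile(flow, &net);

  net.InitModelParameters(/*seed=*/0);  // applies INIT_ZERO/UNIFORM/NORMAL/ORTHO
  // ... run training steps that update the learnable tensors ...
  net.SaveParameters(flow);             // copy learned weights back into the flow
  flow->Save("/tmp/model.flow");

  // A later run can restore the weights into a freshly compiled network:
  net.LoadParameters(*flow);
}
```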
if (tensor->producer_ != nullptr) { diff --git a/sling/myelin/compute.h b/sling/myelin/compute.h index 209811b9..1d7de8ed 100644 --- a/sling/myelin/compute.h +++ b/sling/myelin/compute.h @@ -131,12 +131,6 @@ class Library : public Transformations { // Find kernels implementing operation. const Kernels &Lookup(const string &op) const; - // Find kernel and add to singleton library. The singleton library does not - // own the kernel. - bool Singleton(const string &op, - const string &name, - Library *singleton) const; - private: // Register custom kernel. CustomKernel &RegisterCustomKernel(const string &op, const string &name, @@ -145,9 +139,6 @@ class Library : public Transformations { // Map from op name to kernels implementing the op. std::unordered_map kernels_; - // Whether kernels are owned by library. - bool owns_kernels_ = true; - // Empty kernel list. Kernels no_kernels_; }; @@ -397,6 +388,10 @@ class Tensor { bool ref() const { return ref_; } void set_ref(bool ref) { ref_ = ref; } + // Reference to dynamically sized tensor channel. + bool dynamic() const { return dynamic_; } + void set_dynamic(bool dynamic) { dynamic_ = dynamic; } + // Tensor shape. const Shape &shape() const { return shape_; } int rank() const { return shape_.rank(); } @@ -611,6 +606,9 @@ class Tensor { // Size of of channel elements based on this tensor. int ChannelElementSize() const; + // Size of elements along an axis. + int AxisSize(int axis) const; + // Return corresponding gradient tensor. Tensor *Gradient() const; @@ -637,6 +635,9 @@ class Tensor { // Tensor reference. bool ref_ = false; + // Reference to dynamically sized tensor channel. + bool dynamic_ = false; + // Tensor shape. Shape shape_; @@ -681,8 +682,8 @@ class Tensor { // Constant tensors are global and cannot be modified. bool constant_ = false; - // Initialize tensor with random values from normal distribution. - bool random_init_ = false; + // Initialization for tensor. + Flow::Variable::Initialization init_ = Flow::Variable::INIT_ZERO; // Local tensors are allocated in the instance data block. bool local_ = true; @@ -845,19 +846,19 @@ class Channel { void clear() { resize(0); } // Change size of channel. - void resize(int n); + void resize(size_t n); // Change size of channel and clear all elements. - void reset(int n); + void reset(size_t n); // Reserve space for channel elements. - void reserve(int n); + void reserve(size_t n); // Zero-fill element in channel. - void zero(int n); + void zero(size_t n); // Return pointer to channel element. - char *at(int index) const { + char *at(size_t index) const { return data_ + (index * element_size_); } @@ -868,7 +869,7 @@ class Channel { void pop() { resize(size_ - 1); } // Return the number of elements in the channel. - int size() const { return size_; } + size_t size() const { return size_; } // Return placement of channel. Placement placement() const { @@ -889,10 +890,10 @@ class Channel { char *data_ = nullptr; // Number of elements in channel. - int size_ = 0; + size_t size_ = 0; // Number of allocated elements. - int capacity_ = 0; + size_t capacity_ = 0; // A tensor describing the element type of the channel. const Tensor *format_; @@ -1012,6 +1013,7 @@ class ProfileSummary { class Instance { public: // Create data instance. + Instance() : data_(nullptr), cell_(nullptr) {} Instance(const Cell *cell); Instance(const Flow::Function *func) : Instance(func->cell) {} @@ -1106,6 +1108,19 @@ class Instance { return SetReference(var->tensor, address); } + // Sets a dynamic tensor to channel. 
+ void SetChannel(const Tensor *param, Channel *channel) { + DCHECK(param != nullptr); + DCHECK(param->IsLocal()) << param->name(); + DCHECK(param->dynamic()) << param->name(); + DCHECK(param->cell() == cell_) << param->name(); + *reinterpret_cast(data_ + param->offset()) = channel; + } + void SetChannel(const Flow::Variable *var, Channel *channel) { + DCHECK(var->tensor != nullptr) << var->name; + return SetChannel(var->tensor, channel); + } + // Clear instance tensor. void Clear(const Tensor *param) { memset(GetAddress(param), 0, param->space()); @@ -1164,6 +1179,36 @@ class Instance { const Cell *cell_; }; +// Resizable array of cell instances. +class InstanceArray { + public: + // Create empty array of cell instances. + InstanceArray(Cell *cell); + + // Deallocate instance array. + ~InstanceArray(); + + // Index operator. + Instance &operator[](size_t index) { return *(begin_ + index); } + const Instance &operator[](size_t index) const { return *(begin_ + index); } + + // Size and capacity. + size_t size() const { return end_ - begin_; } + size_t capacity() const { return limit_ - begin_; } + + // Resize array. This will never shrink the capacity of the array. + void Resize(size_t size); + + // Deallocate all the instances and reset the capacity to zero. + void Clear(); + + private: + Cell *cell_; // cell type for instances + Instance *begin_; // begining of instance array + Instance *end_; // end of used instances + Instance *limit_; // end of allocated instances +}; + // A cell contains generated code for executing computation of a function. class Cell { public: @@ -1364,15 +1409,16 @@ class Network { // Add resource to network. This is deleted together with the network. void AddResource(Resource *resource) { resources_.push_back(resource); } - // Initialize learnable weights with random values from a normal distribution. - void InitLearnableWeights(int64 seed = 0, - float mean = 0.0, - float stddev = 1e-4); + // Initialize model parameters. + void InitModelParameters(int64 seed = 0); // Save weights after training. This copies the value of each learnable tensor // in the network to the corresponding variable in the flow. This clears the // learning flag for the variable and turns it into a constant. - void SaveLearnedWeights(Flow *flow) const; + void SaveParameters(Flow *flow) const; + + // Copy weight from flow for learnable tensors. + void LoadParameters(const Flow &flow); // Runtime support functions. Runtime *runtime() const { return runtime_; } diff --git a/sling/myelin/flow.cc b/sling/myelin/flow.cc index 2e2c835b..9f8ce294 100644 --- a/sling/myelin/flow.cc +++ b/sling/myelin/flow.cc @@ -291,7 +291,7 @@ class Parser { Parser(const char *ptr, const char *end) : ptr_(ptr), end_(end) {} // Get data buffer from input and advance the current input pointer. - const char *Get(int len) { + const char *Get(size_t len) { CHECK_LE(len, end_ - ptr_) << "Unexpected end of input"; const char *p = ptr_; ptr_ += len; @@ -490,6 +490,7 @@ string Flow::Variable::TypeString() const { string str; if (ref()) str.append("&"); str.append(TypeTraits::of(type).name()); + if (dynamic()) str.append("<>"); if (!shape.scalar()) { str.append("["); str.append(shape.ToString()); @@ -502,7 +503,12 @@ string Flow::Variable::DataString() const { // Locate data. 
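`Instance::SetChannel` together with the new `DYNAMIC` variable flag lets a cell consume a dynamically sized channel instead of a fixed-size reference. A minimal sketch, assuming a hypothetical cell "encoder" whose input tensor is marked dynamic:

```cpp
#include "sling/myelin/compute.h"

using namespace sling::myelin;

// Feed a variable-length sequence to a cell through a channel.
void RunOnSequence(Network &net, size_t length) {
  const Cell *cell = net.GetCell("encoder");
  Tensor *input = net.GetParameter("encoder/input");  // dynamic() tensor

  Channel sequence(input);   // element layout taken from the format tensor
  sequence.resize(length);   // one element per time step

  Instance data(cell);
  data.SetChannel(input, &sequence);  // instance slot points to the channel
  data.Compute();
}
```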
const char *p = data; if (p == nullptr) return "∅"; - if (ref()) { + if (dynamic()) { + p = *reinterpret_cast(p); + if (p == nullptr) return "null"; + p = *reinterpret_cast(p); + if (p == nullptr) return "null"; + } else if (ref()) { p = *reinterpret_cast(p); if (p == nullptr) return "null"; } @@ -804,9 +810,9 @@ void Flow::Read(const char *data, size_t size) { // Read header. Parser parser(data, data + size); int magic = parser.GetInt(); - CHECK_EQ(magic, kMagic) << "not a flow file"; + CHECK_EQ(magic, MAGIC) << "not a flow file"; int version = parser.GetInt(); - CHECK(version >= 3 && version <= 5) + CHECK(version >= 3 && version <= 6) << "unsupported flow file version " << version; if (version >= 5) parser.GetInt(); // unused flags @@ -850,6 +856,16 @@ void Flow::Read(const char *data, size_t size) { var->shape.add(size == -1 ? batch_size_ : size); } + // Get attributes. + if (version >= 6) { + int num_attrs = parser.GetInt(); + for (int j = 0; j < num_attrs; ++j) { + string name = parser.GetString(); + string value = parser.GetString(); + var->SetAttr(name, value); + } + } + // Get optional variable constant. int64 size = parser.GetLong(); if (size != 0) { @@ -975,8 +991,8 @@ void Flow::Save(const string &filename, int version) const { // Write header (magic and version). CHECK_GE(version, 3); - CHECK_LE(version, kVersion); - file.WriteInt(kMagic); + CHECK_LE(version, VERSION); + file.WriteInt(MAGIC); file.WriteInt(version); if (version >= 5) file.WriteInt(0); // unused flags @@ -1008,6 +1024,15 @@ void Flow::Save(const string &filename, int version) const { file.WriteInt(var->shape.dim(d)); } + // Write attributes. + if (version >= 6) { + file.WriteInt(var->attrs().size()); + for (const auto &attr : var->attrs()) { + file.WriteString(attr.name); + file.WriteString(attr.value); + } + } + // Write size. file.WriteInt64(var->size); @@ -2001,6 +2026,18 @@ bool Flow::IsConsistent() const { } } + // Check connectors. + for (const Connector *cnx : cnxs_) { + for (const Variable *link : cnx->links) { + // Check that link variable is in flow. + if (std::find(vars_.begin(), vars_.end(), link) == vars_.end()) { + LOG(WARNING) << "Link variable " << link->name << " is not in flow " + << "for connector " << cnx->name; + return false; + } + } + } + return true; } @@ -2014,6 +2051,7 @@ string Flow::ToString() const { if (var->in()) StringAppendF(&str, " in"); if (var->out()) StringAppendF(&str, " out"); if (var->unique()) StringAppendF(&str, " unique"); + if (var->is(Flow::Variable::NOGRADIENT)) StringAppendF(&str, " nograd"); if (var->constant()) { StringAppendF(&str, ", %" PRIu64 " bytes", var->size); } diff --git a/sling/myelin/flow.h b/sling/myelin/flow.h index 904acc22..5ea14bf5 100644 --- a/sling/myelin/flow.h +++ b/sling/myelin/flow.h @@ -326,8 +326,8 @@ class Flow { struct Function; // Flow file version - static const int kVersion = 5; - static const int kMagic = 0x776f6c66; + static const int VERSION = 6; + static const int MAGIC = 0x776f6c66; // Flow artifact. template struct Artifact { @@ -343,27 +343,33 @@ class Flow { } return static_cast(this); } - T *clear(uint32 flag, bool disable = true) { - return set(flag, !disable); - } string name; // artifact name uint32 flags = 0; // artifact flags (meaning depends on artifact type) }; // Flow variable. - struct Variable : public Artifact { + struct Variable : public Artifact, public Attributes { // Variable flags. 
enum Flag { - NONE = 0, // no flags - IN = 1, // input variable - OUT = 2, // output variable - REF = 4, // reference variable - LEARNABLE = 8, // learnable global variable - UNIQUE = 16, // input with single gradient - RANDOM = 32, // initialize with random values - ROW = 64, // request row-major order - COL = 128, // request column-major order + NONE = 0, // no flags + IN = 1, // input variable + OUT = 2, // output variable + REF = 4, // reference variable + LEARNABLE = 8, // learnable global variable + UNIQUE = 16, // input with single gradient + DYNAMIC = 32, // dynamically sized tensor channel + ROW = 64, // request row-major order + COL = 128, // request column-major order + NOGRADIENT = 256, // do not compute gradient for variable + }; + + // Initialization for learnable parameters. + enum Initialization { + INIT_ZERO = 0, // initialize to zero + INIT_UNIFORM = 1, // uniform random initialization + INIT_NORMAL = 2, // normal-distributed initialization + INIT_ORTHO = 3, // normal-distributed orthogonal initialization }; // Add alias for variable. @@ -384,25 +390,25 @@ class Flow { // Input variable flag. bool in() const { return is(IN); } Variable *set_in(bool enable = true) { return set(IN, enable); } - Variable *clear_in(bool disable = true) { return clear(IN, disable); } // Output variable flag. bool out() const { return is(OUT); } Variable *set_out(bool enable = true) { return set(OUT, enable); } - Variable *clear_out(bool disable = true) { return clear(OUT, disable); } // Reference variable flag. bool ref() const { return is(REF); } Variable *set_ref(bool enable = true) { return set(REF, enable); } - Variable *clear_ref(bool disable = true) { return clear(REF, disable); } // Learnable variable flag. bool learnable() const { return is(LEARNABLE); } Variable *set_learnable(bool enable = true) { return set(LEARNABLE, enable); } - Variable *clear_learnable(bool disable = true) { - return clear(LEARNABLE, disable); + + // Dynamic size flag. + bool dynamic() const { return is(DYNAMIC); } + Variable *set_dynamic(bool enable = true) { + return set(DYNAMIC, enable); } // Unique gradient flag. @@ -410,18 +416,6 @@ class Flow { Variable *set_unique(bool enable = true) { return set(UNIQUE, enable); } - Variable *clear_unique(bool disable = true) { - return clear(UNIQUE, disable); - } - - // Random intialize flag. - bool random() const { return is(RANDOM); } - Variable *set_random(bool enable = true) { - return set(RANDOM, enable); - } - Variable *clear_random(bool disable = true) { - return clear(RANDOM, disable); - } // Check if variable is a constant. bool constant() const { return data != nullptr; } @@ -490,6 +484,7 @@ class Flow { Shape shape; // variable shape char *data = nullptr; // data for constants (owned by flow) uint64_t size = 0; // size of data in bytes + Initialization init = INIT_ZERO; // initialization Operation *producer = nullptr; // operation producing variable std::vector consumers; // list of consumers of variable @@ -500,7 +495,7 @@ class Flow { struct Operation : public Artifact, public Attributes { // Variable flags. enum Flag { - NONE = 0, // no flags + NONE = 0, // no flags NOGRADIENT = 1, // do not compute gradient for op }; @@ -568,7 +563,7 @@ class Flow { struct Function : public Artifact { // Variable flags. 
enum Flag { - NONE = 0, // no flags + NONE = 0, // no flags TRAINING = 1, // function only needed for training BACKPROP = 2, // build gradient for function }; @@ -581,21 +576,15 @@ class Flow { Function *set_training(bool enable = true) { return set(TRAINING, enable); } - Function *clear_training(bool disable = true) { - return clear(TRAINING, disable); - } // Back-propagation flag. bool backprop() const { return is(BACKPROP); } Function *set_backkprop(bool enable = true) { return set(BACKPROP, enable); } - Function *clear_backprop(bool disable = true) { - return clear(BACKPROP, disable); - } std::vector ops; // ops for function in compute order - std::vector unused; // unused input variables + std::vector unused; // unused input/output variables Cell *cell = nullptr; // cell for function }; @@ -650,7 +639,7 @@ class Flow { void Read(const char *data, size_t size); // Save flow to file. - void Save(const string &filename, int version = kVersion) const; + void Save(const string &filename, int version = VERSION) const; // Analyze flow. void Analyze(const Transformations &transformations); diff --git a/sling/myelin/generator/elementwise.cc b/sling/myelin/generator/elementwise.cc index df49b1d5..abf3ee10 100644 --- a/sling/myelin/generator/elementwise.cc +++ b/sling/myelin/generator/elementwise.cc @@ -182,6 +182,24 @@ bool ElementwiseIndexGenerator::AllocateRegisters() { if (!offset_.is_valid()) return false; } + // Allocate registers for sparse iterator. + if (sparse_) { + bitmap_ = rr.try_alloc(); + if (!bitmap_.is_valid()) return false; + bits_ = rr.try_alloc(); + if (!bits_.is_valid()) return false; + mask_ = rr.try_alloc(); + if (!mask_.is_valid()) return false; + iend_ = rr.try_alloc(); + if (!iend_.is_valid()) return false; + } + + // Assignment target needs a base register. + if (output_ref_ != nullptr) { + input_[0]->base = rr.try_alloc(); + if (!input_[0]->base.is_valid()) return false; + } + // Allocate registers for iterators. for (auto *it : iterators_) { if (it->type == REPEAT || it->type == BROADCAST) { @@ -192,14 +210,16 @@ bool ElementwiseIndexGenerator::AllocateRegisters() { } // Allocate registers for locators. + std::vector simple_locators; for (auto *loc : locators_) { switch (loc->iterator->type) { case SIMPLE: case SCALAR: - // Allocate base register for non-instance variables. - if (loc->var->offset() == -1 || loc->var->ref()) { - loc->base = rr.try_alloc(); - if (!loc->base.is_valid()) return false; + // Base register only needed for non-instance variables. Allocation of + // a register is deferred until registers for other locators have been + // allocated. + if (loc->var->IsGlobal() || loc->var->ref()) { + if (!loc->base.is_valid()) simple_locators.push_back(loc); } break; case CONST: @@ -208,15 +228,15 @@ bool ElementwiseIndexGenerator::AllocateRegisters() { break; case REPEAT: // Allocate base register for non-instance variables. - if (loc->var->offset() == -1 || loc->var->ref()) { - loc->base = rr.try_alloc(); + if (loc->var->IsGlobal() || loc->var->ref()) { + if (!loc->base.is_valid()) loc->base = rr.try_alloc(); if (!loc->base.is_valid()) return false; } break; case SINGLE: case BROADCAST: // Allocate base and broadcast registers. - loc->base = rr.try_alloc(); + if (!loc->base.is_valid()) loc->base = rr.try_alloc(); if (!loc->base.is_valid()) return false; loc->repeat = rr.try_alloc(); if (!loc->repeat.is_valid()) return false; @@ -226,22 +246,21 @@ bool ElementwiseIndexGenerator::AllocateRegisters() { }; } - // Assignment target needs a base register. 
- if (output_ref_ != nullptr && !input_[0]->base.is_valid()) { - input_[0]->base = rr.try_alloc(); - if (!input_[0]->base.is_valid()) return false; + // Allocate registers for simple locators. These locators are loop-invariant + // and can either be initialized before the loop or on demand inside the + // loop depending on how many registers are available. + jit::Register scratch = jit::no_reg; + bool ondemand = rr.num_free() < simple_locators.size(); + if (ondemand) { + scratch = rr.try_alloc(); + if (!scratch.is_valid()) return false; } - - // Allocate registers for sparse iterator. - if (sparse_) { - bitmap_ = rr.try_alloc(); - if (!bitmap_.is_valid()) return false; - bits_ = rr.try_alloc(); - if (!bits_.is_valid()) return false; - mask_ = rr.try_alloc(); - if (!mask_.is_valid()) return false; - iend_ = rr.try_alloc(); - if (!iend_.is_valid()) return false; + for (auto *loc : simple_locators) { + loc->base = rr.try_alloc(); + if (!loc->base.is_valid()) { + loc->base = scratch; + loc->ondemand = true; + } } // Try to allocate extra base registers as an optimization. The base registers @@ -268,7 +287,7 @@ bool ElementwiseIndexGenerator::AllocateRegisters() { base_regs[loc->var->offset()] = loc->base; } } - } else { + } else if (loc->var->IsGlobal()) { loc->base = rr.try_alloc(); } } @@ -282,7 +301,7 @@ void ElementwiseIndexGenerator::GenerateInit() { // Load tensor addresses and initialize index registers. MacroAssembler *masm = masm_; for (auto *loc : locators_) { - if (loc->base.is_valid() && !loc->shared) { + if (loc->base.is_valid() && !loc->shared && !loc->ondemand) { __ LoadTensorAddress(loc->base, loc->var); } if (loc->repeat.is_valid()) { @@ -507,6 +526,11 @@ Operand ElementwiseIndexGenerator::addr(Express::Var *var) { CHECK(Valid(var)); Locator *loc = LookupLocator(var); + // Load base address on demand if needed. + if (loc->ondemand) { + masm_->LoadTensorAddress(loc->base, loc->var); + } + // Return operand for accessing variable. switch (loc->iterator->type) { case SIMPLE: diff --git a/sling/myelin/generator/elementwise.h b/sling/myelin/generator/elementwise.h index 502c993c..247dd9b2 100644 --- a/sling/myelin/generator/elementwise.h +++ b/sling/myelin/generator/elementwise.h @@ -105,6 +105,7 @@ class ElementwiseIndexGenerator : public IndexGenerator { jit::Register base = jit::no_reg; // base address register Iterator *iterator = nullptr; // iterator for iterating over elements bool shared = false; // shared base register + bool ondemand = false; // load base on demand size_t broadcast = 0; // broadcast iterations jit::Register repeat = jit::no_reg; // broadcast counter diff --git a/sling/myelin/gradient.cc b/sling/myelin/gradient.cc index cf045e25..bd818ef9 100644 --- a/sling/myelin/gradient.cc +++ b/sling/myelin/gradient.cc @@ -39,7 +39,7 @@ Gradients::Gradients(Flow *flow, // Create adjoints. for (Flow::Variable *v : vars) { // Constants have trivial derivatives. - if (v->constant()) continue; + if (v->constant() || v->is(Flow::Variable::NOGRADIENT)) continue; // Only floats are differentiable. if (v->type != DT_FLOAT && v->type != DT_DOUBLE) continue; @@ -49,6 +49,7 @@ Gradients::Gradients(Flow *flow, if (v->in()) dv->set_out(); if (v->out()) dv->set_in(); dv->set_ref(v->ref()); + dv->set_dynamic(v->dynamic()); // Connect adjoint to primal variable to ensure common layout. 
if (v->learnable()) flow->Connect({dv, v}); diff --git a/sling/myelin/graph.cc b/sling/myelin/graph.cc index 55e91040..d65ef71c 100644 --- a/sling/myelin/graph.cc +++ b/sling/myelin/graph.cc @@ -148,7 +148,11 @@ static void AppendOp(string *str, if (options.op_type_as_label) { if (op->HasAttr("expr")) { if (op->type == "Assign") str->append("↤ "); - str->append(op->GetAttr("expr")); + string expr = op->GetAttr("expr"); + for (char c : expr) { + str->push_back(c); + if (c == ';') str->append(" "); + } } else if (op->HasAttr("var")) { str->append("➔ "); str->append(op->GetAttr("var")); @@ -242,6 +246,7 @@ static void AppendVar(string *str, if (var->in()) str->append("in "); if (var->out()) str->append("out "); if (var->unique()) str->append("unique "); + if (var->is(Flow::Variable::NOGRADIENT)) str->append("nograd "); str->append("var "); str->append(var->name); if (!var->aliases.empty()) { diff --git a/sling/myelin/kernel/arithmetic.cc b/sling/myelin/kernel/arithmetic.cc index fc8af4f7..a2f7eac1 100644 --- a/sling/myelin/kernel/arithmetic.cc +++ b/sling/myelin/kernel/arithmetic.cc @@ -350,6 +350,8 @@ struct Expression { if (type == DT_FLOAT || type == DT_DOUBLE) { // Perform dry-run to estimate the number of SIMD registers needed. MacroAssembler masm(nullptr, 0, options); + masm.AllocateFunctionRegisters(); + masm.rr().reserve_all(); Expression expr(step, &masm, 0); CHECK(expr.AllocateRegisters()) << "Register overflow"; @@ -360,6 +362,18 @@ struct Expression { return spare_regs; } + // Return the number of registers used by expression. + static int RegisterUsage(const Step *step, const Options &options) { + MacroAssembler masm(nullptr, 0, options); + masm.AllocateFunctionRegisters(); + masm.rr().reserve_all(); + Expression expr(step, &masm, 0); + int before = masm.rr().num_free(); + CHECK(expr.AllocateRegisters()) << "Register overflow in " << step->name(); + int after = masm.rr().num_free(); + return before - after; + } + // Representative output (or input) from expression. Tensor *prototype; @@ -826,9 +840,11 @@ class ExpressionTransformer : public Transformer { Express expr; expr.Parse(fused_recipe); auto *vt = expr.Variable(Express::INPUT, target_index); - auto *v0 = expr.Variable(Express::INPUT, 0); + auto *v0 = expr.Variable(InputType(fused->inputs[0]), 0); vt->id = 0; + vt->type = Express::INPUT; v0->id = target_index; + v0->type = InputType(fused->inputs[0]); fused_recipe = expr.AsRecipe(); fused->SwapInputs(0, target_index); } @@ -1026,7 +1042,7 @@ class Calculate : public Kernel { return true; } - void Adjust(Step *step) override { + void Adjust(Step *step, const Options &options) override { Expression expression(step, nullptr); step->set_variant(expression.generator->Name()); @@ -1071,6 +1087,10 @@ class Calculate : public Kernel { } } } + + // Reserve extra registers. + int regs = Expression::RegisterUsage(step, options); + step->SetRegisterUsage(regs); } void Generate(Step *step, MacroAssembler *masm) override { diff --git a/sling/myelin/kernel/array.cc b/sling/myelin/kernel/array.cc index c56acf2f..facb203f 100644 --- a/sling/myelin/kernel/array.cc +++ b/sling/myelin/kernel/array.cc @@ -42,7 +42,7 @@ class Reshape : public Kernel { } void Adjust(Step *step) override { - CHECK(step->AllowInPlace(0, 0, true)); + CHECK(step->AllowInPlace(0, 0, true)) << step->name(); } void Generate(Step *step, MacroAssembler *masm) override { @@ -127,7 +127,7 @@ class Resize : public Kernel { public: bool Supports(Step *step) override { // Check inputs and outputs. 
- if (step->indegree() != 3 || step->outdegree() != 1) return false; + if (step->indegree() != 1 || step->outdegree() != 1) return false; Tensor *x = step->input(0); Tensor *y = step->output(0); if (x->type() != y->type()) return false; @@ -362,9 +362,6 @@ class Slice : public Kernel { return true; } - void Adjust(Step *step) override { - } - void Generate(Step *step, MacroAssembler *masm) override { // Get inputs and output. Tensor *source = step->input(0); @@ -401,7 +398,7 @@ class Slice : public Kernel { class BasicConcat : public Kernel { public: string Name() override { return "BasicConcat"; } - string Operation() override { return "ConcatV2"; } + string Operation() override { return "Concat"; } bool Supports(Step *step) override { // Check inputs and outputs. @@ -414,13 +411,11 @@ class BasicConcat : public Kernel { if (!axis->constant()) return false; int a = axis->value(); if (step->output(0)->shape().outer(a) != 1) return false; + if (step->output(0)->dynamic()) return false; return true; } - void Adjust(Step *step) override { - } - void Generate(Step *step, MacroAssembler *masm) override { // Get the number of tensors to concatenate. int n = step->GetAttr("N", step->indegree() - 1); @@ -454,7 +449,7 @@ class BasicConcat : public Kernel { class GeneralConcat : public Kernel { public: string Name() override { return "GeneralConcat"; } - string Operation() override { return "ConcatV2"; } + string Operation() override { return "Concat"; } bool Supports(Step *step) override { // Check inputs and outputs. @@ -475,22 +470,21 @@ class GeneralConcat : public Kernel { if (input->rank() < axis) return false; if (input->shape().outer(axis) != prefix) return false; if (input->type() != output->type()) return false; + if (input->dynamic() != output->dynamic()) return false; } return true; } - void Adjust(Step *step) override { - } - void Generate(Step *step, MacroAssembler *masm) override { // Get the number of tensors to concatenate. int n = step->GetAttr("N", step->indegree() - 1); + Tensor *output = step->output(0); // Allocate registers. Register src = masm->rr().alloc_preferred(rsi); Register dst = masm->rr().alloc_preferred(rdi); - Register out = masm->rr().alloc(); + Register cnt = masm->rr().alloc_preferred(rcx); Register idx = masm->rr().alloc(); std::vector in(n); for (int i = 0; i < n; ++i) in[i] = masm->rr().alloc(); @@ -501,32 +495,39 @@ class GeneralConcat : public Kernel { } // Load output tensor. - __ LoadTensorAddress(out, step->output(0)); - __ xorq(idx, idx); + __ LoadTensorAddress(dst, output); // Loop over outer prefix. Label l; int axis = step->input(n)->value(); - int prefix = step->output(0)->shape().outer(axis); + int repeat = output->shape().outer(axis); + if (output->dynamic()) { + __ LoadDynamicSize(idx, output, repeat); + step->set_variant("DYN"); + } else { + __ movq(idx, Immediate(repeat)); + } __ bind(&l); // Copy input tensors to output. - Tensor *output = step->output(0); + int copied = 0; for (int i = 0; i < n; ++i) { Tensor *input = step->input(i); - int size = axis > 0 ? input->stride(axis - 1) : input->size(); + int size = input->AxisSize(axis); __ movq(src, in[i]); - __ movq(dst, out); - __ Copy(dst, 0, src, 0, size); + __ movq(cnt, Immediate(size)); + __ repmovsb(); __ addq(in[i], Immediate(size)); + copied += size; } // Next chunk. - int size = axis > 0 ? 
output->stride(axis - 1) : output->size(); - __ addq(out, Immediate(size)); - __ incq(idx); - __ cmpq(idx, Immediate(prefix)); - __ j(less, &l); + int size = output->AxisSize(axis); + if (copied != size) { + __ addq(dst, Immediate(size - copied)); + } + __ decq(idx); + __ j(not_zero, &l); } int64 Complexity(const Step *step) override { @@ -558,7 +559,6 @@ class Split : public Kernel { if (axis->type() != DT_INT32 || !axis->constant()) return false; int a = axis->value(); if (a > input->rank() - 1) return false; - if (input->shape().outer(a) != 1) return false; // Check that outputs match the input. Type dt = input->type(); @@ -568,38 +568,74 @@ class Split : public Kernel { Tensor *output = step->output(i); if (output->type() != dt) return false; if (output->rank() != input->rank()) return false; - if (output->shape().outer(a) != 1) return false; if (output->shape().inner(a) != size / n) return false; + if (output->dynamic() != input->dynamic()) return false; } return true; } - void Adjust(Step *step) override { - } - void Generate(Step *step, MacroAssembler *masm) override { // Get input. Tensor *input = step->input(0); int n = step->input(1)->value(); int axis = step->input(2)->value(); - int chunk_size = input->shape().inner(axis) / n; - int stride = input->stride(axis) * chunk_size; + int repeat = input->shape().outer(axis); // Allocate registers. Register src = masm->rr().alloc_preferred(rsi); Register dst = masm->rr().alloc_preferred(rdi); - Register in = masm->rr().alloc(); + Register cnt = masm->rr().alloc_preferred(rcx); + Register idx = masm->rr().alloc_preferred(rcx); // Load input tensor. - __ LoadTensorAddress(in, input); + __ LoadTensorAddress(src, input); + + if (input->dynamic() || repeat > 1) { + // Load output tensors. + step->set_variant("REP"); + std::vector out(n); + for (int i = 0; i < n; ++i) { + out[i] = masm->rr().alloc(); + __ LoadTensorAddress(out[i], step->output(i)); + } - // Copy input tensors to output. - int offset = 0; - for (int i = 0; i < n; ++i) { - __ leaq(src, Operand(in, offset)); - __ LoadTensorAddress(dst, step->output(i)); - __ Copy(dst, 0, src, 0, stride); - offset += stride; + // Loop over outer prefix. + Label l; + if (input->dynamic()) { + __ LoadDynamicSize(idx, input, repeat); + step->set_variant("DYN"); + } else { + __ movq(idx, Immediate(repeat)); + } + __ bind(&l); + + // Split input to output. + int copied = 0; + for (int i = 0; i < n; ++i) { + Tensor *output = step->output(i); + int size = output->AxisSize(axis); + __ movq(dst, out[i]); + __ movq(cnt, Immediate(size)); + __ repmovsb(); + __ addq(out[i], Immediate(size)); + copied += size; + } + + // Next chunk. + int size = input->AxisSize(axis); + if (copied != size) { + __ addq(src, Immediate(size - copied)); + } + __ decq(idx); + __ j(not_zero, &l); + } else { + // Simple non-repeated split. + for (int i = 0; i < n; ++i) { + int size = step->output(i)->AxisSize(axis); + __ LoadTensorAddress(dst, step->output(i)); + __ movq(cnt, Immediate(size)); + __ repmovsb(); + } } } @@ -864,7 +900,7 @@ class PoolingGather : public Kernel { M->RequireOrder(ROW_MAJOR); // Reserve registers. - int regs = SIMDAssembler::RegisterUsage(type) + 8; + int regs = SIMDAssembler::RegisterUsage(type) + 9; step->SetRegisterUsage(regs); } @@ -1145,6 +1181,7 @@ class AssignAddScatter : public Kernel { // Reserve registers. 
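As a reading aid for the new Concat code path above: for each of the repeat outer slices, every input contributes AxisSize(axis) bytes (apparently the byte size of one slice along the concatenation axis), and the destination pointer is then advanced past any bytes the inputs did not produce. The sketch below redoes the same bookkeeping with memcpy on plain buffers; it is an editorial illustration with invented helper names, not code from the patch.

```cpp
// Editorial sketch (not part of the patch): concatenation of several inputs
// along an axis, expressed as the same per-slice byte copies that the
// generated repmovsb loop performs. All sizes are in bytes, as in the kernel.
#include <cstdio>
#include <cstring>
#include <vector>

// Copy 'repeat' outer slices; input i contributes size[i] bytes per slice.
void ConcatSlices(const std::vector<const char*> &in,
                  const std::vector<int> &size,   // per-input slice size in bytes
                  int repeat, char *dst, int out_slice_size) {
  std::vector<const char*> src = in;
  for (int r = 0; r < repeat; ++r) {
    int copied = 0;
    for (size_t i = 0; i < src.size(); ++i) {
      std::memcpy(dst + copied, src[i], size[i]);
      src[i] += size[i];          // advance input to its next slice
      copied += size[i];
    }
    dst += out_slice_size;        // skip any trailing bytes not produced here
  }
}

int main() {
  // Two inputs with outer dimension 2: slices of 3 and 2 bytes respectively.
  char a[] = {1, 2, 3, 4, 5, 6};      // 2 x 3
  char b[] = {7, 8, 9, 10};           // 2 x 2
  char out[10] = {};                  // 2 x 5
  ConcatSlices({a, b}, {3, 2}, /*repeat=*/2, out, /*out_slice_size=*/5);
  for (char c : out) std::printf("%d ", c);
  std::printf("\n");                  // 1 2 3 7 8 4 5 6 9 10
  return 0;
}
```

The rewritten Split kernel is the mirror image of this: one source pointer advanced across slices, with a destination pointer per output.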
int regs = SIMDAssembler::RegisterUsage(type) + 8; + if (args.scaler) regs++; step->SetRegisterUsage(regs); } diff --git a/sling/myelin/kernel/avx-math.cc b/sling/myelin/kernel/avx-math.cc index c023c55d..562a4455 100644 --- a/sling/myelin/kernel/avx-math.cc +++ b/sling/myelin/kernel/avx-math.cc @@ -40,8 +40,8 @@ class AVXFltArgMax : public Kernel { if (!CPU::Enabled(AVX2)) return false; // Check inputs and outputs. - if (step->inputs().size() != 1) return false; - if (step->outputs().size() != 1) return false; + if (step->indegree() != 1) return false; + if (step->outdegree() != 1) return false; Tensor *x = step->input(0); Tensor *y = step->output(0); diff --git a/sling/myelin/kernel/cuda-array.cc b/sling/myelin/kernel/cuda-array.cc index f4bf2110..764aba2d 100644 --- a/sling/myelin/kernel/cuda-array.cc +++ b/sling/myelin/kernel/cuda-array.cc @@ -27,7 +27,7 @@ class CUDABasicConcat : public CUDAKernel { static const int WORD_SIZE = 4; string Name() override { return "CUDABasicConcat"; } - string Operation() override { return "ConcatV2"; } + string Operation() override { return "Concat"; } bool Supports(Step *step) override { // Requires CUDA support. diff --git a/sling/myelin/kernel/generic-math.cc b/sling/myelin/kernel/generic-math.cc index a93d4f58..7a7736e1 100644 --- a/sling/myelin/kernel/generic-math.cc +++ b/sling/myelin/kernel/generic-math.cc @@ -15,6 +15,7 @@ #include "sling/myelin/kernel/generic.h" #include +#include #include #include "sling/myelin/compute.h" @@ -224,13 +225,17 @@ class GenericFltArgMax : public Kernel { bool Supports(Step *step) override { // Check inputs and outputs. - if (step->inputs().size() != 1) return false; - if (step->outputs().size() != 1) return false; + if (step->indegree() != 1) return false; + if (step->outdegree() != 1) return false; Tensor *x = step->input(0); Tensor *y = step->output(0); // Check type. - if (x->type() != DT_FLOAT) return false; + if (x->type() != DT_FLOAT && x->type() != DT_DOUBLE && + x->type() != DT_INT8 && x->type() != DT_INT16 && + x->type() != DT_INT32 && x->type() != DT_INT64) { + return false; + } if (y->type() != DT_INT32 && y->type() != DT_INT64) return false; if (y->elements() != 1) return false; @@ -241,37 +246,90 @@ class GenericFltArgMax : public Kernel { // Get input and output. Tensor *x = step->input(0); Tensor *y = step->output(0); + Type dt = x->type(); // Assign registers. Register input = masm->rr().alloc(); Register output = masm->rr().alloc(); Register idx = masm->rr().alloc(); Register best = masm->rr().alloc(); - XMMRegister value = masm->mm().allocx(); - XMMRegister maxval = masm->mm().allocx(); + Register ivalue = masm->rr().alloc(); + Register iextremum = masm->rr().alloc(); + XMMRegister fvalue = masm->mm().allocx(); + XMMRegister fextremum = masm->mm().allocx(); // Load tensor locations. __ LoadTensorAddress(input, x); __ LoadTensorAddress(output, y); - // Initialize max value. + // Initialize min/max value. __ movq(best, Immediate(-1)); - float inf = minimum_ ? 
INFINITY : -INFINITY; - __ movss(maxval, Operand(masm->GetConstant(inf)->address())); + if (minimum_) { + switch (dt) { + case DT_INT8: + case DT_INT16: + case DT_INT32: + case DT_INT64: + __ movq(iextremum, masm->MaxVal()->address()); + break; + case DT_FLOAT: + __ movss(fextremum, masm->MaxVal()->address()); + break; + case DT_DOUBLE: + __ movsd(fextremum, masm->MaxVal()->address()); + break; + default: ; + } + } else { + switch (dt) { + case DT_INT8: + case DT_INT16: + case DT_INT32: + case DT_INT64: + __ movq(iextremum, masm->MinVal()->address()); + break; + case DT_FLOAT: + __ movss(fextremum, masm->MinVal()->address()); + break; + case DT_DOUBLE: + __ movsd(fextremum, masm->MinVal()->address()); + break; + default: ; + } + } // Loop over elements in tensor. __ xorq(idx, idx); Label loop; __ LoopStart(&loop); - // Get next input value. - __ movss(value, Operand(input, idx, times_4)); - // Check if value is greater/less than current max value. + // Check if next value is greater/less than current extremum. Label l1; - __ ucomiss(value, maxval); - __ j(minimum_ ? above_equal : below_equal, &l1); - __ movss(maxval, value); + switch (dt) { + case DT_INT8: + case DT_INT16: + case DT_INT32: + case DT_INT64: + __ LoadInteger(ivalue, input, idx, dt); + __ cmpq(ivalue, iextremum); + __ j(minimum_ ? greater_equal : less_equal, &l1); + __ movq(iextremum, ivalue); + break; + case DT_FLOAT: + __ movss(fvalue, Operand(input, idx, times_4)); + __ ucomiss(fvalue, fextremum); + __ j(minimum_ ? above_equal : below_equal, &l1); + __ movss(fextremum, fvalue); + break; + case DT_DOUBLE: + __ movsd(fvalue, Operand(input, idx, times_8)); + __ ucomisd(fvalue, fextremum); + __ j(minimum_ ? above_equal : below_equal, &l1); + __ movsd(fextremum, fvalue); + break; + default: ; + } __ movq(best, idx); __ bind(&l1); diff --git a/sling/myelin/kernel/generic-matmul.cc b/sling/myelin/kernel/generic-matmul.cc index 2525e723..6487055e 100644 --- a/sling/myelin/kernel/generic-matmul.cc +++ b/sling/myelin/kernel/generic-matmul.cc @@ -639,7 +639,7 @@ class TransposeTransformer : public Transformer { var_refs++; } } - if (var_refs == 0) var->clear_out(); + if (var_refs == 0) var->set_out(false); } updates++; diff --git a/sling/myelin/kernel/generic.cc b/sling/myelin/kernel/generic.cc index b0b4c916..d3740aff 100644 --- a/sling/myelin/kernel/generic.cc +++ b/sling/myelin/kernel/generic.cc @@ -141,6 +141,14 @@ class RenameTransformer : public Transformer { op->type = "Add"; renames++; } + if (op->type == "ConcatV2") { + op->type = "Concat"; + renames++; + } + if (op->type == "GatherV2") { + op->type = "Gather"; + renames++; + } } return renames > 0; @@ -192,7 +200,7 @@ class IdentityTransformer : public Transformer { noops.push_back(op); } } - } else if (op->type == "Concat" || op->type == "ConcatV2") { + } else if (op->type == "Concat") { // Eliminate concatenations with only one input. int n = op->GetAttr("N", 0); if (n == 1) { @@ -227,7 +235,7 @@ class FlattenConcatTransformer : public Transformer { private: // Returns true if the operation is a concatenation. static bool IsConcat(const Flow::Operation &operation) { - if (operation.type != "ConcatV2") return false; + if (operation.type != "Concat") return false; if (!operation.HasAttr("N")) return false; const int num_to_concat = operation.GetAttr("N", -1); if (num_to_concat <= 0) return false; @@ -310,24 +318,14 @@ class FlattenConcatTransformer : public Transformer { } }; -// Normalizes "Gather" operations: -// 1. Replaces "GatherV2" with "Gather". -// 2. 
Removes the "axis" input when it is zero. +// Normalizes "Gather" operations. Removes the "axis" input when it is zero. class GatherTransformer : public Transformer { public: string Name() override { return "GatherTransformer"; } bool Transform(Flow *flow) override { + // Remove the "axis" input when it is zero. bool transformed = false; - - // First, normalize the operation type. - for (Flow::Operation *op : flow->ops()) { - if (op->type != "GatherV2") continue; - op->type = "Gather"; - transformed = true; - } - - // Next, remove the "axis" input when it is zero. for (Flow::Operation *op : flow->ops()) { if (op->type != "Gather") continue; // types were normalized above @@ -433,7 +431,7 @@ class StandardTyper : public Typer { } // Infer shape for concat operation. - if (op->type == "ConcatV2") { + if (op->type == "Concat") { int n = op->GetAttr("N", 0); if (n > op->indegree()) return false; int axis; diff --git a/sling/myelin/kernel/gradients.cc b/sling/myelin/kernel/gradients.cc index 65df99c6..8618281d 100644 --- a/sling/myelin/kernel/gradients.cc +++ b/sling/myelin/kernel/gradients.cc @@ -510,15 +510,41 @@ void concat_grad(Flow::Operation *op, Gradients *g) { int N = op->GetAttr("N", 0); int axis; CHECK(op->inputs.back()->GetData(&axis)); - Shape begin; - begin.redim(op->outputs[0]->rank()); + + // If all inputs have the same size the gradient is a Split. Otherwise a + // number of Slice ops are used. + bool equisized = true; for (int i = 0; i < N; ++i) { - auto vi = op->inputs[i]; - g->add(vi, g->Slice(g->d(v), g->Const(begin), vi->shape)); - begin.set(axis, begin.dim(axis) + vi->shape.dim(axis)); + if (op->inputs[0]->shape != op->inputs[i]->shape) equisized = false; + } + + if (equisized) { + auto parts = g->Split(g->d(v), N, axis); + for (int i = 0; i < N; ++i) { + g->add(op->inputs[i], parts[i]); + } + } else { + Shape begin; + begin.redim(op->outputs[0]->rank()); + for (int i = 0; i < N; ++i) { + auto vi = op->inputs[i]; + g->add(vi, g->Slice(g->d(v), g->Const(begin), vi->shape)); + begin.set(axis, begin.dim(axis) + vi->shape.dim(axis)); + } } } +// v_1, ..., v_n = split(v, n, axis) +// dv = concat({v_1, ..., v_n}, axis) +void split_grad(Flow::Operation *op, Gradients *g) { + auto v = op->inputs[0]; + int axis; + CHECK(op->inputs[2]->GetData(&axis)); + auto parts = op->outputs; + for (auto &p : parts) p = g->d(p); + g->add(v, g->Concat(parts, axis)); +} + // y = sum(x) // dx = dy void sum_grad(Flow::Operation *op, Gradients *g) { @@ -623,7 +649,8 @@ void RegisterStandardGradients() { RegisterGradient("Reshape", reshape_grad); RegisterGradient("Gather", gather_grad); RegisterGradient("GatherSum", gathersum_grad); - RegisterGradient("ConcatV2", concat_grad); + RegisterGradient("Concat", concat_grad); + RegisterGradient("Split", split_grad); RegisterGradient("Sum", sum_grad); RegisterGradient("Min", min_grad); RegisterGradient("Max", max_grad); diff --git a/sling/myelin/learning.cc b/sling/myelin/learning.cc index 7927d3f3..783198fa 100644 --- a/sling/myelin/learning.cc +++ b/sling/myelin/learning.cc @@ -192,8 +192,9 @@ void GradientDescentOptimizer::BuildOptimizer(const GradientMap &gradmap, auto *dv = it.second; // Add scaled gradient to parameters. 
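The rewritten concat gradient above distinguishes two cases: when all inputs have the same shape, the upstream gradient is simply split back into equal parts along the axis (and the new split_grad is the mirror image, concatenating the per-output gradients); otherwise each input receives the slice of the upstream gradient starting at its running offset. A tiny numeric check of the equal-size case, on flat arrays rather than Flow variables, is sketched below as an editorial illustration.

```cpp
// Editorial sketch (not part of the patch): for y = concat(x1, x2) along the
// last axis with equal-sized inputs, dx_i is the matching contiguous block of
// dy, i.e. exactly what a Split of dy would produce.
#include <array>
#include <cstdio>

int main() {
  // dy for y = concat(x1, x2), where x1 and x2 each have 3 elements.
  std::array<float, 6> dy = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};

  std::array<float, 3> dx1, dx2;
  for (int i = 0; i < 3; ++i) {
    dx1[i] = dy[i];        // first block of dy
    dx2[i] = dy[3 + i];    // second block of dy
  }

  std::printf("dx1 = %.1f %.1f %.1f\n", dx1[0], dx1[1], dx1[2]);
  std::printf("dx2 = %.1f %.1f %.1f\n", dx2[0], dx2[1], dx2[2]);
  return 0;
}
```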
- if (lambda_ != 0.0) { - tf.Assign(v, tf.Add(tf.Mul(tf.Sub(tf.One(), tf.Const(lambda_)), v), + float lambda = v->GetAttr("l2reg", lambda_); + if (lambda != 0.0) { + tf.Assign(v, tf.Add(tf.Mul(tf.Sub(tf.One(), tf.Const(lambda)), v), tf.Mul(dv, multiplier))); } else { tf.AssignAdd(v, tf.Mul(dv, multiplier)); diff --git a/sling/myelin/macro-assembler.cc b/sling/myelin/macro-assembler.cc index 0f6548f1..2c932a36 100644 --- a/sling/myelin/macro-assembler.cc +++ b/sling/myelin/macro-assembler.cc @@ -25,6 +25,25 @@ namespace myelin { using namespace jit; +// Register usage: +// +// rax: 1st return register, temporary register +// rbx: extra register, caller-preserved +// rcx: 4th argument register, temporary register +// rdx: 3rd argument register, 2nd return register, temporary register +// rdi: 1st argument register, temporary register +// rsi: 2nd argument register, temporary register +// rbp: data instance address, caller-preserved +// rsp: stack pointer +// r8 : 5th argument register, temporary register +// r9 : 6th argument register, temporary register +// r10: temporary register +// r11: temporary register +// r12: extra register, caller-preserved +// r13: extra register, caller-preserved +// r14: extra register, caller-preserved +// r15: extra register, caller-preserved, profiler timestamp register + #ifdef NDEBUG // Base register for data instance. static Register datareg = rbp; @@ -38,7 +57,7 @@ static Register tsreg = r14; #endif Register Registers::try_alloc() { - for (int r = 0; r < kNumRegisters; ++r) { + for (int r = 0; r < NUM_REGISTERS; ++r) { if (!used(r)) { use(r); return Register::from_code(r); @@ -54,7 +73,7 @@ Register Registers::alloc() { } Register Registers::try_alloc_preserved() { - for (int r = 0; r < kNumRegisters; ++r) { + for (int r = 0; r < NUM_REGISTERS; ++r) { if (!used(r) && preserved(r)) { use(r); return Register::from_code(r); @@ -107,7 +126,7 @@ Register Registers::arg(int n) { } Register Registers::alloc_extra() { - for (int r = 0; r < kNumRegisters; ++r) { + for (int r = 0; r < NUM_REGISTERS; ++r) { if (extra(r) && !saved(r)) { reserve(r); use(r); @@ -124,6 +143,14 @@ void Registers::reserve(int r) { used_regs_ &= ~(1 << r); } +void Registers::reserve_all() { + for (int r = 0; r < NUM_REGISTERS; ++r) { + if (extra(r) && !saved(r)) { + reserve(r); + } + } +} + void Registers::free(int r) { CHECK(saved(r)) << r; CHECK(!used(r)) << r; @@ -131,14 +158,16 @@ void Registers::free(int r) { used_regs_ |= (1 << r); } + bool Registers::usage(int n) { switch (n) { - case 13: reserve(r15); FALLTHROUGH_INTENDED; - case 12: reserve(r14); FALLTHROUGH_INTENDED; - case 11: reserve(r13); FALLTHROUGH_INTENDED; - case 10: reserve(r12); FALLTHROUGH_INTENDED; - case 9: reserve(rbx); FALLTHROUGH_INTENDED; - case 8: case 7: case 6: case 5: case 4: case 3: case 2: case 1: case 0: + case 14: reserve(r15); FALLTHROUGH_INTENDED; + case 13: reserve(r14); FALLTHROUGH_INTENDED; + case 12: reserve(r13); FALLTHROUGH_INTENDED; + case 11: reserve(r12); FALLTHROUGH_INTENDED; + case 10: reserve(rbx); FALLTHROUGH_INTENDED; + case 9: case 8: case 7: case 6: case 5: + case 4: case 3: case 2: case 1: case 0: return true; } return false; @@ -146,14 +175,14 @@ bool Registers::usage(int n) { int Registers::num_free() const { int n = 0; - for (int r = 0; r < kNumRegisters; ++r) { + for (int r = 0; r < NUM_REGISTERS; ++r) { if (!used(r)) n++; } return n; } int SIMDRegisters::try_alloc(bool extended) { - int n = extended ? kNumZRegisters : kNumXRegisters; + int n = extended ? 
NUM_Z_REGISTERS : NUM_X_REGISTERS; for (int i = next_; i < n + next_; ++i) { int r = i % n; if ((used_regs_ & (1 << r)) == 0) { @@ -172,7 +201,7 @@ int SIMDRegisters::alloc(bool extended) { } OpmaskRegister OpmaskRegisters::try_alloc() { - for (int r = 0; r < kNumRegisters; ++r) { + for (int r = 0; r < NUM_REGISTERS; ++r) { OpmaskRegister k = OpmaskRegister::from_code(r); if (!used(k)) { use(k); @@ -237,13 +266,7 @@ Register MacroAssembler::instance() const { return datareg; } -void MacroAssembler::Prologue() { - // Zero upper part of YMM register if CPU needs it to avoid AVX-SSE transition - // penalties. - if (CPU::VZeroNeeded() && Enabled(AVX)) { - vzeroupper(); - } - +void MacroAssembler::AllocateFunctionRegisters() { // Reserve data instance register. rr_.reserve(datareg); rr_.use(datareg); @@ -253,6 +276,17 @@ void MacroAssembler::Prologue() { rr_.reserve(tsreg); rr_.use(tsreg); } +} + +void MacroAssembler::Prologue() { + // Zero upper part of YMM register if CPU needs it to avoid AVX-SSE transition + // penalties. + if (CPU::VZeroNeeded() && Enabled(AVX)) { + vzeroupper(); + } + + // Reserve registers for function. + AllocateFunctionRegisters(); // Save preserved registers on stack. if (rr_.saved(rbp)) pushq(rbp); @@ -341,18 +375,27 @@ void MacroAssembler::LoadTensorAddress(Register dst, Tensor *tensor) { if (tensor->IsGlobal()) { DCHECK(tensor->data() != nullptr); load_extern(dst, tensor->data(), tensor->name(), options_.pic); - if (tensor->ref()) { + if (tensor->dynamic()) { + movq(dst, Operand(dst)); + movq(dst, Operand(dst)); + } else if (tensor->ref()) { movq(dst, Operand(dst)); } } else if (tensor->offset() == 0) { - if (tensor->ref()) { + if (tensor->dynamic()) { + movq(dst, Operand(datareg)); + movq(dst, Operand(dst)); + } else if (tensor->ref()) { movq(dst, Operand(datareg)); } else { movq(dst, datareg); } } else { DCHECK(tensor->offset() != -1) << tensor->name(); - if (tensor->ref()) { + if (tensor->dynamic()) { + movq(dst, Operand(datareg, tensor->offset())); + movq(dst, Operand(dst)); + } else if (tensor->ref()) { movq(dst, Operand(datareg, tensor->offset())); } else { leaq(dst, Operand(datareg, tensor->offset())); @@ -371,7 +414,7 @@ void MacroAssembler::LoadTensorAddress(Register dst, Tensor *tensor, std::vector index; CHECK(indices->GetData(&index)); int offset = tensor->offset(index); - if (tensor->IsGlobal() || tensor->ref()) { + if (tensor->IsGlobal() || tensor->ref() || tensor->dynamic()) { LoadTensorAddress(dst, tensor); if (offset != 0) addq(dst, Immediate(offset)); } else { @@ -383,7 +426,11 @@ void MacroAssembler::LoadTensorAddress(Register dst, Tensor *tensor, Register acc = rr_.alloc(); if (indices->rank() < 2) { LoadTensorAddress(dst, tensor); - if (indices->ref()) { + if (indices->dynamic()) { + movq(iptr, Operand(instance(), indices->offset())); + movq(iptr, Operand(iptr)); + movsxlq(acc, Operand(iptr)); + } else if (indices->ref()) { movq(iptr, Operand(instance(), indices->offset())); movsxlq(acc, Operand(iptr)); } else if (indices->IsGlobal()) { @@ -434,6 +481,15 @@ void MacroAssembler::LoadTensorDeviceAddress(Register dst, Tensor *tensor) { } } +void MacroAssembler::LoadDynamicSize(Register dst, Tensor *tensor, int scalar) { + CHECK(tensor->dynamic()); + CHECK(!tensor->ref()); + CHECK(tensor->IsLocal()); + movq(dst, Operand(datareg, tensor->offset())); + movq(dst, Operand(dst, sizeof(char *))); + Multiply(dst, scalar); +} + void MacroAssembler::Copy(Register dst, int ddisp, Register src, int sdisp, int size) { diff --git 
a/sling/myelin/macro-assembler.h b/sling/myelin/macro-assembler.h index 2473d887..dd33d870 100644 --- a/sling/myelin/macro-assembler.h +++ b/sling/myelin/macro-assembler.h @@ -15,6 +15,8 @@ #ifndef SLING_MYELIN_MACRO_ASSEMBLER_H_ #define SLING_MYELIN_MACRO_ASSEMBLER_H_ +#include + #include "sling/myelin/compute.h" #include "third_party/jit/assembler.h" @@ -37,11 +39,11 @@ class Registers { typedef jit::Register Register; // An x64 CPU has 16 general 64-bit registers. - static const int kNumRegisters = 16; + static const int NUM_REGISTERS = 16; // Initialize registers. Registers() - : used_regs_(kPreservedRegisters), saved_regs_(0) {} + : used_regs_(PRESERVED_REGISTERS), saved_regs_(0) {} Registers(const Registers &rr) : used_regs_(rr.used_regs_), saved_regs_(rr.saved_regs_) {} Registers &operator=(const Registers &rr) { @@ -87,11 +89,12 @@ class Registers { bool used(Register r) { return used(r.code()); } // Reset allocated registers. - void reset() { used_regs_ = kPreservedRegisters & ~saved_regs_; } + void reset() { used_regs_ = PRESERVED_REGISTERS & ~saved_regs_; } // Reserve callee-saved register for use. void reserve(int r); void reserve(Register r) { reserve(r.code()); } + void reserve_all(); // Free callee-saved register after it has been restored. void free(int r); @@ -106,11 +109,11 @@ class Registers { bool saved(Register r) { return saved(r.code()); } // Check if register is a callee-saved register. - static bool preserved(int r) { return ((1 << r) & kPreservedRegisters) != 0; } + static bool preserved(int r) { return ((1 << r) & PRESERVED_REGISTERS) != 0; } static bool preserved(Register r) { return preserved(r.code()); } // Check if register is an extra callee-saved register. - static bool extra(int r) { return ((1 << r) & kExtraRegisters) != 0; } + static bool extra(int r) { return ((1 << r) & EXTRA_REGISTERS) != 0; } static bool extra(Register r) { return extra(r.code()); } // Return the number of free registers. @@ -118,7 +121,7 @@ class Registers { private: // Preserved registers. - static const int kPreservedRegisters = + static const int PRESERVED_REGISTERS = 1 << Register::kCode_rbx | 1 << Register::kCode_rsp | 1 << Register::kCode_rbp | @@ -128,7 +131,7 @@ class Registers { 1 << Register::kCode_r15; // Extra callee-saved registers. - static const int kExtraRegisters = + static const int EXTRA_REGISTERS = 1 << Register::kCode_rbx | 1 << Register::kCode_r12 | 1 << Register::kCode_r13 | @@ -150,8 +153,8 @@ class SIMDRegisters { typedef jit::ZMMRegister ZMMRegister; // An x64 CPU has up to 16 SIMD registers (or 32 in AVX512 mode). - static const int kNumXRegisters = 16; - static const int kNumZRegisters = 32; + static const int NUM_X_REGISTERS = 16; + static const int NUM_Z_REGISTERS = 32; // Initialize SIMD registers. SIMDRegisters() : used_regs_(0) {} @@ -217,7 +220,7 @@ class OpmaskRegisters { typedef jit::OpmaskRegister OpmaskRegister; // There are 8 opmask registers (k0 to k7) where k0 is a constant register. - static const int kNumRegisters = 8; + static const int NUM_REGISTERS = 8; // Initialize opmask registers. OpmaskRegisters() : used_regs_(kSpecialRegisters) {} @@ -307,6 +310,9 @@ class MacroAssembler : public jit::Assembler { MacroAssembler(void *buffer, int buffer_size, const Options &options); ~MacroAssembler(); + // Allocate registers for function prologue/epilogue. + void AllocateFunctionRegisters(); + // Generate function prologue. 
void Prologue(); @@ -359,6 +365,16 @@ class MacroAssembler : public jit::Assembler { return data; } + // Get data block for minimum value for type. + template StaticData *MinVal(int repeat = 1) { + return GetConstant(std::numeric_limits::lowest(), repeat); + } + + // Get data block for maximum value for type. + template StaticData *MaxVal(int repeat = 1) { + return GetConstant(std::numeric_limits::max(), repeat); + } + // Generate static data blocks in the code buffer. void GenerateDataBlocks(); @@ -371,6 +387,9 @@ class MacroAssembler : public jit::Assembler { // Load address of tensor on device. void LoadTensorDeviceAddress(Register dst, Tensor *tensor); + // Load size of dynamic tensor (e.g. channel) and multiply with scalar. + void LoadDynamicSize(Register dst, Tensor *tensor, int scalar = 1); + // Emit breakpoint. void Breakpoint() { int3(); } diff --git a/sling/myelin/rnn.cc b/sling/myelin/rnn.cc index 0e4c4162..5ac73875 100644 --- a/sling/myelin/rnn.cc +++ b/sling/myelin/rnn.cc @@ -20,248 +20,843 @@ namespace sling { namespace myelin { -void BiLSTM::LSTM::Initialize(const Network &net, const string &name) { - // Initialize LSTM cell. +RNN::Variables RNN::Build(Flow *flow, + Flow::Variable *input, + Flow::Variable *dinput) { + // Build RNN cell. + bool learn = dinput != nullptr; + Variables vars; + FlowBuilder tf(flow, name); + auto dt = input->type; + int input_dim = input->dim(1); + int rnn_dim = spec.dim; + + // Build inputs. + auto *x = tf.Placeholder("input", dt, input->shape, true); + auto *h_in = tf.Placeholder("h_in", dt, {1, rnn_dim}, true); + Flow::Variable *c_in = nullptr; + if (spec.type != GRU) { + c_in = tf.Placeholder("c_in", dt, {1, rnn_dim}, true); + } + + // Build recurrent unit. + Flow::Variable *h_out = nullptr; // hidden output + Flow::Variable *c_out = nullptr; // control output + Flow::Variable *residual = nullptr; // residial gate for highway connection + switch (spec.type) { + case LSTM: { + // Standard LSTM. 
+ auto *x2i = tf.Parameter("x2i", dt, {input_dim, rnn_dim}); + auto *h2i = tf.Parameter("h2i", dt, {rnn_dim, rnn_dim}); + auto *bi = tf.Parameter("bi", dt, {1, rnn_dim}); + tf.RandomOrtho(x2i); + tf.RandomOrtho(h2i); + + auto *x2f = tf.Parameter("x2f", dt, {input_dim, rnn_dim}); + auto *h2f = tf.Parameter("h2f", dt, {rnn_dim, rnn_dim}); + auto *bf = tf.Parameter("bf", dt, {1, rnn_dim}); + tf.RandomOrtho(x2f); + tf.RandomOrtho(h2f); + + auto *x2g = tf.Parameter("x2g", dt, {input_dim, rnn_dim}); + auto *h2g = tf.Parameter("h2g", dt, {rnn_dim, rnn_dim}); + auto *bg = tf.Parameter("bg", dt, {1, rnn_dim}); + tf.RandomOrtho(x2g); + tf.RandomOrtho(h2g); + + auto *x2o = tf.Parameter("x2o", dt, {input_dim, rnn_dim}); + auto *h2o = tf.Parameter("h2o", dt, {rnn_dim, rnn_dim}); + auto *bo = tf.Parameter("bo", dt, {1, rnn_dim}); + tf.RandomOrtho(x2o); + tf.RandomOrtho(h2o); + + // i = sigmoid(x * x2i + h_in * h2i + bi) + auto *ia = tf.Add(tf.MatMul(x, x2i), tf.Add(tf.MatMul(h_in, h2i), bi)); + auto *i = tf.Name(tf.Sigmoid(ia), "i"); + + // f = sigmoid(x * x2f + h_in * h2f + bf) + auto *fa = tf.Add(tf.MatMul(x, x2f), tf.Add(tf.MatMul(h_in, h2f), bf)); + auto *f = tf.Name(tf.Sigmoid(fa), "f"); + + // g = tanh(x * x2g + h_in * h2g + bg) + auto *ga = tf.Add(tf.MatMul(x, x2g), tf.Add(tf.MatMul(h_in, h2g), bg)); + auto *g = tf.Name(tf.Tanh(ga), "g"); + + // o = sigmoid(x * x2o + h_in * h2o + bo) + auto *oa = tf.Add(tf.MatMul(x, x2o), tf.Add(tf.MatMul(h_in, h2o), bo)); + auto *o = tf.Name(tf.Sigmoid(oa), "o"); + + // residual = sigmoid(x * x2r + h_in * h2r + br) + if (spec.highways) { + auto *x2r = tf.Parameter("x2r", dt, {input_dim, rnn_dim}); + auto *h2r = tf.Parameter("h2r", dt, {rnn_dim, rnn_dim}); + auto *br = tf.Parameter("br", dt, {1, rnn_dim}); + tf.RandomOrtho(x2r); + tf.RandomOrtho(h2r); + + auto *ra = tf.Add(tf.Add(tf.MatMul(x, x2r),tf.MatMul(h_in, h2r)), br); + residual = tf.Name(tf.Sigmoid(ra), "r"); + } + + // c_out = f * c_in + i * g + c_out = tf.Add(tf.Mul(f, c_in), tf.Mul(i, g)); + + // h_out = o * tanh(c_out) + h_out = tf.Mul(o, tf.Tanh(c_out)); + break; + } + + case DRAGNN_LSTM: { + // DRAGNN LSTM with peephole and coupled gates. 
+ auto *x2i = tf.Parameter("x2i", dt, {input_dim, rnn_dim}); + auto *h2i = tf.Parameter("h2i", dt, {rnn_dim, rnn_dim}); + auto *c2i = tf.Parameter("c2i", dt, {rnn_dim, rnn_dim}); + auto *bi = tf.Parameter("bi", dt, {1, rnn_dim}); + tf.RandomOrtho(x2i); + tf.RandomOrtho(h2i); + tf.RandomOrtho(c2i); + + auto *x2o = tf.Parameter("x2o", dt, {input_dim, rnn_dim}); + auto *h2o = tf.Parameter("h2o", dt, {rnn_dim, rnn_dim}); + auto *c2o = tf.Parameter("c2o", dt, {rnn_dim, rnn_dim}); + auto *bo = tf.Parameter("bo", dt, {1, rnn_dim}); + tf.RandomOrtho(x2o); + tf.RandomOrtho(h2o); + tf.RandomOrtho(c2o); + + auto *x2c = tf.Parameter("x2c", dt, {input_dim, rnn_dim}); + auto *h2c = tf.Parameter("h2c", dt, {rnn_dim, rnn_dim}); + auto *bc = tf.Parameter("bc", dt, {1, rnn_dim}); + tf.RandomOrtho(x2c); + tf.RandomOrtho(h2c); + + // i = sigmoid(x * x2i + h_in * h2i + c_in * c2i + bi) + auto *ia = tf.Add(tf.MatMul(x, x2i), + tf.Add(tf.MatMul(h_in, h2i), + tf.Add(tf.MatMul(c_in, c2i), bi))); + auto *i = tf.Name(tf.Sigmoid(ia), "i"); + + // f = 1 - i + auto *f = tf.Name(tf.Sub(tf.One(), i), "f"); + + // w = tanh(x * x2c + h_in * h2c + bc) + auto *wa = tf.Add(tf.MatMul(x, x2c), + tf.Add(tf.MatMul(h_in, h2c), bc)); + auto *w = tf.Name(tf.Tanh(wa), "w"); + + // c_out = i * w + f * c_in + c_out = tf.Add(tf.Mul(i, w), tf.Mul(f, c_in)); + + // o = sigmoid(x * x2o + c_out * c2o + h_in * h2o + bo) + auto *oa = tf.Add(tf.MatMul(x, x2o), + tf.Add(tf.MatMul(c_out, c2o), + tf.Add(tf.MatMul(h_in, h2o), bo))); + auto *o = tf.Name(tf.Sigmoid(oa), "o"); + + // r = sigmoid(x * x2r + h_in * h2r + br) + if (spec.highways) { + auto *x2r = tf.Parameter("x2r", dt, {input_dim, rnn_dim}); + auto *h2r = tf.Parameter("h2r", dt, {rnn_dim, rnn_dim}); + auto *br = tf.Parameter("br", dt, {1, rnn_dim}); + tf.RandomOrtho(x2r); + tf.RandomOrtho(h2r); + auto *ra = tf.Add(tf.Add(tf.MatMul(x, x2r),tf.MatMul(h_in, h2r)), br); + residual = tf.Name(tf.Sigmoid(ra), "r"); + } + + // h_out = o * tanh(c_out) + h_out = tf.Mul(o, tf.Tanh(c_out)); + break; + } + + case DOZAT_LSTM: { + // Standard LSTM with one matrix multiplication. + int gates = spec.highways ? 5 : 4; + auto *w = tf.Parameter("W", dt, {input_dim + rnn_dim, gates * rnn_dim}); + auto *b = tf.Parameter("b", dt, {1, gates * rnn_dim}); + tf.RandomOrtho(w); + + // Preactivations. + auto *xh = tf.Concat({x, h_in}); + auto p = tf.Split(tf.Add(tf.MatMul(xh, w), b), gates, 1); + + // Gates. + auto *f = tf.Name(tf.Sigmoid(p[0]), "f"); + auto *i = tf.Name(tf.Sigmoid(p[1]), "i"); + auto *o = tf.Name(tf.Sigmoid(p[2]), "o"); + auto *g = tf.Name(tf.Tanh(p[3]), "g"); + if (spec.highways) { + residual = tf.Name(tf.Sigmoid(p[4]), "r"); + } + + // Outputs. + c_out = tf.Add(tf.Mul(f, c_in), tf.Mul(i, g)); + h_out = tf.Mul(o, tf.Tanh(c_out)); + break; + } + + case PYTORCH_LSTM: { + // Standard LSTM with two matrix multiplications. + int gates = spec.highways ? 5 : 4; + auto *w_ih = tf.Parameter("w_ih", dt, {input_dim, gates * rnn_dim}); + auto *w_hh = tf.Parameter("w_hh", dt, {rnn_dim, gates * rnn_dim}); + auto *b_ih = tf.Parameter("b_ih", dt, {1, gates * rnn_dim}); + auto *b_hh = tf.Parameter("b_hh", dt, {1, gates * rnn_dim}); + tf.RandomOrtho(w_ih); + tf.RandomOrtho(w_hh); + + // Preactivations. + auto *ih = tf.Add(tf.MatMul(x, w_ih), b_ih); + auto *hh = tf.Add(tf.MatMul(h_in, w_hh), b_hh); + auto p = tf.Split(tf.Add(ih, hh), gates, 1); + + // Gates. 
+ auto *f = tf.Name(tf.Sigmoid(p[0]), "f"); + auto *i = tf.Name(tf.Sigmoid(p[1]), "i"); + auto *o = tf.Name(tf.Sigmoid(p[2]), "o"); + auto *g = tf.Name(tf.Tanh(p[3]), "g"); + if (spec.highways) { + residual = tf.Name(tf.Sigmoid(p[4]), "r"); + } + + // Outputs. + c_out = tf.Add(tf.Mul(f, c_in), tf.Mul(i, g)); + h_out = tf.Mul(o, tf.Tanh(c_out)); + break; + } + + case GRU: { + // Gated Recurrent Unit. + auto *x2z = tf.Parameter("x2z", dt, {input_dim, rnn_dim}); + auto *h2z = tf.Parameter("h2z", dt, {rnn_dim, rnn_dim}); + tf.RandomOrtho(x2z); + tf.RandomOrtho(h2z); + + auto *x2r = tf.Parameter("x2r", dt, {input_dim, rnn_dim}); + auto *h2r = tf.Parameter("h2r", dt, {rnn_dim, rnn_dim}); + tf.RandomOrtho(x2r); + tf.RandomOrtho(h2r); + + auto *x2h = tf.Parameter("x2h", dt, {input_dim, rnn_dim}); + auto *h2h = tf.Parameter("h2h", dt, {rnn_dim, rnn_dim}); + tf.RandomOrtho(x2h); + tf.RandomOrtho(h2h); + + // z = sigmoid(x * x2z + h_in * h2z) + auto *za = tf.Add(tf.MatMul(x, x2z), tf.MatMul(h_in, h2z)); + auto *z = tf.Name(tf.Sigmoid(za), "z"); + + // r = sigmoid(x * x2r + h_in * h2r) + auto *ra = tf.Add(tf.MatMul(x, x2r), tf.MatMul(h_in, h2r)); + auto *r = tf.Name(tf.Sigmoid(ra), "r"); + + // h = tanh(x * x2h + (r * h_in) * h2h) + auto *ha = tf.Add(tf.MatMul(x, x2h), tf.MatMul(tf.Mul(r, h_in), h2h)); + auto *h = tf.Name(tf.Tanh(ha), "h"); + + // residual = sigmoid(x * x2b + h_in * h2b) + if (spec.highways) { + auto *x2b = tf.Parameter("x2b", dt, {input_dim, rnn_dim}); + auto *h2b = tf.Parameter("h2b", dt, {rnn_dim, rnn_dim}); + tf.RandomOrtho(x2b); + tf.RandomOrtho(h2b); + auto *ra = tf.Add(tf.MatMul(x, x2b),tf.MatMul(h_in, h2b)); + residual = tf.Name(tf.Sigmoid(ra), "r"); + } + + // h_out = (1 - z) * h_in + z * h + h_out = tf.Add(tf.Mul(tf.Sub(tf.One(), z), h_in), tf.Mul(z, h)); + break; + } + + default: + LOG(FATAL) << "RNN type not supported: " << spec.type; + } + + // Highway connection. + if (residual != nullptr) { + // Highway connection. + auto *bypass = x; + if (input_dim != rnn_dim) { + // Linear transform from input to output dimension. + auto *wx = tf.RandomOrtho(tf.Parameter("Wr", dt, {input_dim, rnn_dim})); + bypass = tf.MatMul(x, wx); + } + h_out = tf.Add(tf.Mul(residual, h_out), + tf.Mul(tf.Sub(tf.One(), residual), bypass)); + } + + // Apply dropout to output. + if (learn && spec.dropout != 0.0) { + auto *mask = tf.Placeholder("mask", DT_FLOAT, {1, rnn_dim}, true); + mask->set(Flow::Variable::NOGRADIENT); + h_out = tf.Mul(h_out, mask); + + // The no-dropout mask is used for testing during training when no dropout + // should be applied. + std::vector ones(rnn_dim, 1.0); + auto *nodropout = tf.Name(tf.Const(ones), "nodropout"); + nodropout->set_out(); + flow->Connect({nodropout, mask}); + } + + // Name RNN outputs. + if (h_out != nullptr) tf.Name(h_out, "h_out"); + if (c_out != nullptr) tf.Name(c_out, "c_out"); + + // Make zero element. + auto *zero = tf.Name(tf.Const(nullptr, dt, {1, rnn_dim}), "zero"); + zero->set_out(); + + // Connect RNN units. + vars.input = x; + vars.output = h_out; + flow->Connect({x, input}); + h_out->set_out()->set_ref(); + flow->Connect({h_in, h_out, zero}); + if (c_in != nullptr) { + c_out->set_out()->set_ref(); + flow->Connect({c_in, c_out, zero}); + + // The control channel has a single-source gradient. + c_in->set_unique(); + } + + // Build gradients for learning. 
+ if (learn) { + auto *gf = Gradient(flow, tf.func()); + vars.dinput = flow->GradientVar(vars.input); + vars.doutput = flow->GradientVar(vars.output); + flow->Connect({vars.dinput, dinput}); + + // Make sink variable for final channel gradients. + auto *sink = tf.Var("sink", dt, {1, rnn_dim})->set_out(); + gf->unused.push_back(sink); + auto *dh_in = flow->GradientVar(h_in); + auto *dh_out = flow->GradientVar(h_out); + flow->Connect({dh_in, dh_out, sink}); + if (c_out != nullptr) { + auto *dc_in = flow->GradientVar(c_in); + auto *dc_out = flow->GradientVar(c_out); + flow->Connect({dc_in, dc_out, sink}); + } + } + + return vars; +} + +void RNN::Initialize(const Network &net) { + // Initialize RNN cell. Control channel is optional. cell = net.GetCell(name); input = net.GetParameter(name + "/input"); h_in = net.GetParameter(name + "/h_in"); h_out = net.GetParameter(name + "/h_out"); - c_in = net.GetParameter(name + "/c_in"); - c_out = net.GetParameter(name + "/c_out"); + c_in = net.LookupParameter(name + "/c_in"); + c_out = net.LookupParameter(name + "/c_out"); + zero = net.GetParameter(name + "/zero"); - // Initialize gradient cell for LSTM. + // Initialize gradient cell for RNN. gcell = cell->Gradient(); if (gcell != nullptr) { primal = cell->Primal(); dinput = input->Gradient(); dh_in = h_in->Gradient(); dh_out = h_out->Gradient(); - dc_in = c_in->Gradient(); - dc_out = c_out->Gradient(); + dc_in = c_in == nullptr ? nullptr : c_in->Gradient(); + dc_out = c_out == nullptr ? nullptr : c_out->Gradient(); + sink = net.GetParameter(name + "/sink"); + } + + // Initialize dropout mask. + if (spec.dropout != 0.0) { + mask = net.GetParameter(name + "/mask"); + nodropout = net.GetParameter(name + "/nodropout"); } } -// Build flows for LSTMs. -BiLSTM::Outputs BiLSTM::Build(Flow *flow, int dim, - Flow::Variable *input, - Flow::Variable *dinput) { - Outputs out; +RNNMerger::Variables RNNMerger::Build(Flow *flow, + Flow::Variable *left, + Flow::Variable *right, + Flow::Variable *dleft, + Flow::Variable *dright) { + Variables vars; - // Build left-to-right LSTM flow. - FlowBuilder lr(flow, name_ + "/lr"); - auto *lr_input = lr.Placeholder("input", input->type, input->shape, true); - out.lr = lr.LSTMLayer(lr_input, dim); + // Build merger cell. + FlowBuilder f(flow, name); + vars.left = f.Placeholder("left", left->type, left->shape); + vars.left->set_dynamic()->set_unique(); - // Build right-to-left LSTM flow. - FlowBuilder rl(flow, name_ + "/rl"); - auto *rl_input = rl.Placeholder("input", input->type, input->shape, true); - out.rl = rl.LSTMLayer(rl_input, dim); + vars.right = f.Placeholder("right", right->type, right->shape); + vars.right->set_dynamic()->set_unique(); - // Connect input to LSTMs. - flow->Connect({input, lr_input, rl_input}); + vars.merged = f.Name(f.Concat({vars.left, vars.right}, 1), "merged"); + vars.merged->set_dynamic(); + flow->Connect({vars.left, left}); + flow->Connect({vars.right, right}); // Build gradients for learning. 
- if (dinput != nullptr) { - Gradient(flow, lr.func()); - Gradient(flow, rl.func()); - out.dlr = flow->GradientVar(lr_input); - out.drl = flow->GradientVar(rl_input); - flow->Connect({dinput, out.dlr, out.drl}); + if (dleft != nullptr && dright != nullptr) { + Gradient(flow, f.func()); + vars.dmerged = flow->GradientVar(vars.merged); + vars.dleft = flow->GradientVar(vars.left); + vars.dright = flow->GradientVar(vars.right); + flow->Connect({vars.dleft, dleft}); + flow->Connect({vars.dright, dright}); } else { - out.dlr = nullptr; - out.drl = nullptr; + vars.dmerged = vars.dleft = vars.dright = nullptr; } - return out; + return vars; } -void BiLSTM::Initialize(const Network &net) { - lr_.Initialize(net, name_ + "/lr"); - rl_.Initialize(net, name_ + "/rl"); +void RNNMerger::Initialize(const Network &net) { + cell = net.GetCell(name); + left = net.GetParameter(name + "/left"); + right = net.GetParameter(name + "/right"); + merged = net.GetParameter(name + "/merged"); + + gcell = cell->Gradient(); + if (gcell != nullptr) { + dmerged = merged->Gradient(); + dleft = left->Gradient(); + dright = right->Gradient(); + } +} + +RNNLayer::RNNLayer(const string &name, const RNN::Spec &spec, bool bidir) + : name_(name), + bidir_(bidir), + dropout_(spec.dropout), + lr_(bidir ? name + "/lr" : name, spec), + rl_(name + "/rl", spec), + merger_(name) {} + +RNN::Variables RNNLayer::Build(Flow *flow, + Flow::Variable *input, + Flow::Variable *dinput) { + if (bidir_) { + // Build left-to-right and right-to-left RNNs. + auto l = lr_.Build(flow, input, dinput); + auto r = rl_.Build(flow, input, dinput); + + // Build channel merger. + auto m = merger_.Build(flow, l.output, r.output, l.doutput, r.doutput); + + // Return outputs. + RNN::Variables vars; + vars.input = l.input; + vars.output = m.merged; + vars.dinput = l.dinput; + vars.doutput = m.dmerged; + + return vars; + } else { + return lr_.Build(flow, input, dinput); + } } -BiLSTMInstance::BiLSTMInstance(const BiLSTM &bilstm) - : bilstm_(bilstm), - lr_(bilstm.lr_.cell), - rl_(bilstm.rl_.cell), - lr_hidden_(bilstm.lr_.h_out), - lr_control_(bilstm.lr_.c_out), - rl_hidden_(bilstm.rl_.h_out), - rl_control_(bilstm.rl_.c_out) {} - -BiChannel BiLSTMInstance::Compute(Channel *input) { - // Reset hidden and control channels. +void RNNLayer::Initialize(const Network &net) { + lr_.Initialize(net); + if (bidir_) { + rl_.Initialize(net); + merger_.Initialize(net); + } +} + +RNNInstance::RNNInstance(const RNNLayer *rnn) + : rnn_(rnn), + lr_(rnn->lr_.cell), + lr_hidden_(rnn->lr_.h_out), + lr_control_(rnn->lr_.c_out), + rl_(rnn->rl_.cell), + rl_hidden_(rnn->rl_.h_out), + rl_control_(rnn->rl_.c_out), + merger_(rnn->merger_.cell), + merged_(rnn->merger_.merged) {} + +Channel *RNNInstance::Compute(Channel *input) { + // Get sequence length. int length = input->size(); - lr_hidden_.reset(length + 1); - rl_hidden_.reset(length + 1); - lr_control_.resize(length + 1); - rl_control_.resize(length + 1); - lr_control_.zero(length); - rl_control_.zero(length); - - // Compute left-to-right LSTM. - for (int i = 0; i < length; ++i) { - // Input. - lr_.Set(bilstm_.lr_.input, input, i); - lr_.Set(bilstm_.lr_.h_in, &lr_hidden_, i > 0 ? i - 1 : length); - lr_.Set(bilstm_.lr_.c_in, &lr_control_, i > 0 ? i - 1 : length); - - // Output. - lr_.Set(bilstm_.lr_.h_out, &lr_hidden_, i); - lr_.Set(bilstm_.lr_.c_out, &lr_control_, i); - - // Compute LSTM cell. + bool ctrl = rnn_->lr_.has_control(); + + // Set pass-through dropout mask. 
+ if (rnn_->lr_.has_mask()) { + lr_.SetReference(rnn_->lr_.mask, rnn_->lr_.nodropout->data()); + } + + // Compute left-to-right RNN. + lr_hidden_.resize(length); + if (ctrl) lr_control_.resize(length); + + if (length > 0) { + lr_.Set(rnn_->lr_.input, input, 0); + lr_.SetReference(rnn_->lr_.h_in, rnn_->lr_.zero->data()); + lr_.Set(rnn_->lr_.h_out, &lr_hidden_, 0); + if (ctrl) { + lr_.SetReference(rnn_->lr_.c_in, rnn_->lr_.zero->data()); + lr_.Set(rnn_->lr_.c_out, &lr_control_, 0); + } + lr_.Compute(); + } + + for (int i = 1; i < length; ++i) { + lr_.Set(rnn_->lr_.input, input, i); + lr_.Set(rnn_->lr_.h_in, &lr_hidden_, i - 1); + lr_.Set(rnn_->lr_.h_out, &lr_hidden_, i); + if (ctrl) { + lr_.Set(rnn_->lr_.c_in, &lr_control_, i - 1); + lr_.Set(rnn_->lr_.c_out, &lr_control_, i); + } lr_.Compute(); } - // Compute right-to-left LSTM. - for (int i = length - 1; i >= 0; --i) { - // Input. - rl_.Set(bilstm_.rl_.input, input, i); - rl_.Set(bilstm_.rl_.h_in, &rl_hidden_, i + 1); - rl_.Set(bilstm_.rl_.c_in, &rl_control_, i + 1); + // Return left-to-right hidden channel for unidirectional RNN. + if (!rnn_->bidir_) return &lr_hidden_; - // Output. - rl_.Set(bilstm_.rl_.h_out, &rl_hidden_, i); - rl_.Set(bilstm_.rl_.c_out, &rl_control_, i); + // Set pass-through dropout mask. + if (rnn_->rl_.has_mask()) { + rl_.SetReference(rnn_->rl_.mask, rnn_->rl_.nodropout->data()); + } + + // Compute right-to-left RNN. + rl_hidden_.resize(length); + if (ctrl) rl_control_.resize(length); + + if (length > 0) { + rl_.Set(rnn_->rl_.input, input, length - 1); + rl_.SetReference(rnn_->rl_.h_in, rnn_->rl_.zero->data()); + rl_.Set(rnn_->rl_.h_out, &rl_hidden_, length - 1); + if (ctrl) { + rl_.SetReference(rnn_->rl_.c_in, rnn_->rl_.zero->data()); + rl_.Set(rnn_->rl_.c_out, &rl_control_, length - 1); + } + rl_.Compute(); + } - // Compute LSTM cell. + for (int i = length - 2; i >= 0; --i) { + rl_.Set(rnn_->rl_.input, input, i); + rl_.Set(rnn_->rl_.h_in, &rl_hidden_, i + 1); + rl_.Set(rnn_->rl_.h_out, &rl_hidden_, i); + if (ctrl) { + rl_.Set(rnn_->rl_.c_in, &rl_control_, i + 1); + rl_.Set(rnn_->rl_.c_out, &rl_control_, i); + } rl_.Compute(); } - return BiChannel(&lr_hidden_, &rl_hidden_); + // Merge outputs. 
+ merged_.resize(length); + merger_.SetChannel(rnn_->merger_.left, &lr_hidden_); + merger_.SetChannel(rnn_->merger_.right, &rl_hidden_); + merger_.SetChannel(rnn_->merger_.merged, &merged_); + merger_.Compute(); + + return &merged_; } -BiLSTMLearner::BiLSTMLearner(const BiLSTM &bilstm) - : bilstm_(bilstm), - lr_gradient_(bilstm.lr_.gcell), - rl_gradient_(bilstm.rl_.gcell), - lr_hidden_(bilstm.lr_.h_out), - lr_control_(bilstm.lr_.c_out), - rl_hidden_(bilstm.rl_.h_out), - rl_control_(bilstm.rl_.c_out), - dlr_hidden_(bilstm.lr_.dh_in), - dlr_control_(bilstm.lr_.dc_in), - drl_hidden_(bilstm.rl_.dh_in), - drl_control_(bilstm.rl_.dc_in), - dinput_(bilstm.lr_.dinput) {} - -BiLSTMLearner::~BiLSTMLearner() { - for (Instance *data : lr_) delete data; - for (Instance *data : rl_) delete data; +RNNLearner::RNNLearner(const RNNLayer *rnn) + : rnn_(rnn), + lr_fwd_(rnn->lr_.cell), + lr_hidden_(rnn->lr_.h_out), + lr_control_(rnn->lr_.c_out), + lr_bkw_(rnn->lr_.gcell), + lr_dhidden_(rnn->lr_.dh_in), + lr_dcontrol_(rnn->lr_.dc_in), + rl_fwd_(rnn->rl_.cell), + rl_hidden_(rnn->rl_.h_out), + rl_control_(rnn->rl_.c_out), + rl_bkw_(rnn->rl_.gcell), + rl_dhidden_(rnn->rl_.dh_in), + rl_dcontrol_(rnn->rl_.dc_in), + dinput_(rnn_->lr_.dinput), + merger_(rnn->merger_.cell), + splitter_(rnn->merger_.gcell), + merged_(rnn->merger_.merged), + dleft_(rnn->merger_.dleft), + dright_(rnn->merger_.dright), + mask_(rnn_->lr_.mask) { + if (rnn->dropout_ != 0.0) { + mask_.resize(1); + } } -BiChannel BiLSTMLearner::Compute(Channel *input) { - // Allocate instances. +Channel *RNNLearner::Compute(Channel *input) { + // Get sequence length. int length = input->size(); - for (auto *data : lr_) delete data; - for (auto *data : rl_) delete data; - lr_.resize(length); - rl_.resize(length); - for (int i = 0; i < length; ++i) { - lr_[i] = new Instance(bilstm_.lr_.cell); - rl_[i] = new Instance(bilstm_.rl_.cell); + bool ctrl = rnn_->lr_.has_control(); + + // Set up dropout mask. + bool dropout = rnn_->dropout_ != 0.0; + if (dropout) { + float *mask = reinterpret_cast(mask_.at(0)); + float rate = rnn_->dropout_; + float scaler = 1.0 / (1.0 - rate); + int size = rnn_->lr_.spec.dim; + for (int i = 0; i < size; ++i) { + mask[i] = Random() < rate ? 0.0 : scaler; + } } - // Reset hidden and control channels. - lr_hidden_.reset(length + 1); - rl_hidden_.reset(length + 1); - lr_control_.resize(length + 1); - rl_control_.resize(length + 1); - lr_control_.zero(length); - rl_control_.zero(length); - - // Compute left-to-right LSTM. - for (int i = 0; i < length; ++i) { - Instance *lr = lr_[i]; + // Compute left-to-right RNN. + lr_fwd_.Resize(length); + lr_hidden_.resize(length); + if (ctrl) lr_control_.resize(length); + + if (length > 0) { + Instance &data = lr_fwd_[0]; + data.Set(rnn_->lr_.input, input, 0); + data.SetReference(rnn_->lr_.h_in, rnn_->lr_.zero->data()); + data.Set(rnn_->lr_.h_out, &lr_hidden_, 0); + if (ctrl) { + data.SetReference(rnn_->lr_.c_in, rnn_->lr_.zero->data()); + data.Set(rnn_->lr_.c_out, &lr_control_, 0); + } + if (dropout) { + data.Set(rnn_->lr_.mask, &mask_, 0); + } + data.Compute(); + } - // Input. - lr->Set(bilstm_.lr_.input, input, i); - lr->Set(bilstm_.lr_.h_in, &lr_hidden_, i > 0 ? i - 1 : length); - lr->Set(bilstm_.lr_.c_in, &lr_control_, i > 0 ? 
i - 1 : length); + for (int i = 1; i < length; ++i) { + Instance &data = lr_fwd_[i]; + data.Set(rnn_->lr_.input, input, i); + data.Set(rnn_->lr_.h_in, &lr_hidden_, i - 1); + data.Set(rnn_->lr_.h_out, &lr_hidden_, i); + if (ctrl) { + data.Set(rnn_->lr_.c_in, &lr_control_, i - 1); + data.Set(rnn_->lr_.c_out, &lr_control_, i); + } + if (dropout) { + data.Set(rnn_->lr_.mask, &mask_, 0); + } + data.Compute(); + } - /// Output. - lr->Set(bilstm_.lr_.h_out, &lr_hidden_, i); - lr->Set(bilstm_.lr_.c_out, &lr_control_, i); + // Return left-to-right hidden channel for unidirectional RNN. + if (!rnn_->bidir_) return &lr_hidden_; + + // Compute right-to-left RNN. + rl_fwd_.Resize(length); + rl_hidden_.resize(length); + if (ctrl) rl_control_.resize(length); + + if (length > 0) { + Instance &data = rl_fwd_[length - 1]; + data.Set(rnn_->rl_.input, input, length - 1); + data.SetReference(rnn_->rl_.h_in, rnn_->rl_.zero->data()); + data.Set(rnn_->rl_.h_out, &rl_hidden_, length - 1); + if (ctrl) { + data.SetReference(rnn_->rl_.c_in, rnn_->rl_.zero->data()); + data.Set(rnn_->rl_.c_out, &rl_control_, length - 1); + } + if (dropout) { + data.Set(rnn_->rl_.mask, &mask_, 0); + } + data.Compute(); + } - // Compute LSTM cell. - lr->Compute(); + for (int i = length - 2; i >= 0; --i) { + Instance &data = rl_fwd_[i]; + data.Set(rnn_->rl_.input, input, i); + data.Set(rnn_->rl_.h_in, &rl_hidden_, i + 1); + data.Set(rnn_->rl_.h_out, &rl_hidden_, i); + if (ctrl) { + data.Set(rnn_->rl_.c_in, &rl_control_, i + 1); + data.Set(rnn_->rl_.c_out, &rl_control_, i); + } + if (dropout) { + data.Set(rnn_->rl_.mask, &mask_, 0); + } + data.Compute(); } - // Compute right-to-left LSTM. - for (int i = length - 1; i >= 0; --i) { - Instance *rl = rl_[i]; + // Merge outputs. + merged_.resize(length); + merger_.SetChannel(rnn_->merger_.left, &lr_hidden_); + merger_.SetChannel(rnn_->merger_.right, &rl_hidden_); + merger_.SetChannel(rnn_->merger_.merged, &merged_); + merger_.Compute(); + + return &merged_; +} - // Input. - rl->Set(bilstm_.rl_.input, input, i); - rl->Set(bilstm_.rl_.h_in, &rl_hidden_, i + 1); - rl->Set(bilstm_.rl_.c_in, &rl_control_, i + 1); +Channel *RNNLearner::Backpropagate(Channel *doutput) { + // Clear input gradient. + int length = doutput->size(); + dinput_.reset(length); + bool ctrl = rnn_->lr_.has_control(); + + // Split gradient for bidirectional RNN. + Channel *dleft; + Channel *dright; + if (rnn_->bidir_) { + // Split gradients. + dleft_.resize(length); + dright_.resize(length); + splitter_.SetChannel(rnn_->merger_.dmerged, doutput); + splitter_.SetChannel(rnn_->merger_.dleft, &dleft_); + splitter_.SetChannel(rnn_->merger_.dright, &dright_); + splitter_.Compute(); + dleft = &dleft_; + dright = &dright_; + } else { + dleft = doutput; + dright = nullptr; + } - // Output. - rl->Set(bilstm_.rl_.h_out, &rl_hidden_, i); - rl->Set(bilstm_.rl_.c_out, &rl_control_, i); + // Propagate gradients for left-to-right RNN. 
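The mask built at the top of the learner's forward pass above is inverted dropout: surviving units are scaled by 1/(1 - rate), so the expected value of the masked activation equals the unmasked one and inference needs no rescaling (the nodropout constant of ones is used there instead). A short stand-alone check of that property follows; it is an editorial sketch and uses std::mt19937 rather than the patch's Random() helper.

```cpp
// Editorial sketch (not part of the patch): inverted dropout keeps the
// expected activation unchanged, E[mask * h] = h, because survivors are
// scaled by 1 / (1 - rate).
#include <cstdio>
#include <random>

int main() {
  const float rate = 0.2f;
  const float scaler = 1.0f / (1.0f - rate);
  const float h = 0.5f;               // some hidden activation

  std::mt19937 gen(42);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

  double sum = 0.0;
  const int trials = 1000000;
  for (int t = 0; t < trials; ++t) {
    float mask = uniform(gen) < rate ? 0.0f : scaler;
    sum += mask * h;
  }
  std::printf("mean masked activation: %.4f (target %.4f)\n", sum / trials, h);
  return 0;
}
```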
+ if (dleft != nullptr) { + if (ctrl) lr_dcontrol_.reset(length); + + for (int i = length - 1; i > 0; --i) { + lr_bkw_.Set(rnn_->lr_.primal, &lr_fwd_[i]); + lr_bkw_.Set(rnn_->lr_.dh_out, dleft, i); + lr_bkw_.Set(rnn_->lr_.dh_in, dleft, i - 1); + lr_bkw_.Set(rnn_->lr_.dinput, &dinput_, i); + if (ctrl) { + lr_bkw_.Set(rnn_->lr_.dc_out, &lr_dcontrol_, i); + lr_bkw_.Set(rnn_->lr_.dc_in, &lr_dcontrol_, i - 1); + } + lr_bkw_.Compute(); + } + + if (length > 0) { + void *sink = lr_bkw_.GetAddress(rnn_->lr_.sink); + lr_bkw_.Set(rnn_->lr_.primal, &lr_fwd_[0]); + lr_bkw_.Set(rnn_->lr_.dh_out, dleft, 0); + lr_bkw_.SetReference(rnn_->lr_.dh_in, sink); + lr_bkw_.Set(rnn_->lr_.dinput, &dinput_, 0); + if (ctrl) { + lr_bkw_.Set(rnn_->lr_.dc_out, &lr_dcontrol_, 0); + lr_bkw_.SetReference(rnn_->lr_.dc_in, sink); + } + lr_bkw_.Compute(); + } + } - // Compute LSTM cell. - rl->Compute(); + // Propagate gradients for right-to-left RNN. + if (dright != nullptr) { + if (ctrl) rl_dcontrol_.reset(length); + + for (int i = 0; i < length - 1; ++i) { + rl_bkw_.Set(rnn_->rl_.primal, &rl_fwd_[i]); + rl_bkw_.Set(rnn_->rl_.dh_out, dright, i); + rl_bkw_.Set(rnn_->rl_.dh_in, dright, i + 1); + rl_bkw_.Set(rnn_->rl_.dinput, &dinput_, i); + if (ctrl) { + rl_bkw_.Set(rnn_->rl_.dc_out, &rl_dcontrol_, i); + rl_bkw_.Set(rnn_->rl_.dc_in, &rl_dcontrol_, i + 1); + } + rl_bkw_.Compute(); + } + + if (length > 0) { + void *sink = rl_bkw_.GetAddress(rnn_->rl_.sink); + rl_bkw_.Set(rnn_->rl_.primal, &rl_fwd_[length - 1]); + rl_bkw_.Set(rnn_->rl_.dh_out, dright, length - 1); + rl_bkw_.SetReference(rnn_->rl_.dh_in, sink); + rl_bkw_.Set(rnn_->rl_.dinput, &dinput_, length - 1); + if (ctrl) { + rl_bkw_.Set(rnn_->rl_.dc_out, &rl_dcontrol_, length - 1); + rl_bkw_.SetReference(rnn_->rl_.dc_in, sink); + } + rl_bkw_.Compute(); + } } - return BiChannel(&lr_hidden_, &rl_hidden_); + // Return input gradient. + return &dinput_; } -BiChannel BiLSTMLearner::PrepareGradientChannels(int length) { - dlr_hidden_.reset(length + 1); - drl_hidden_.reset(length + 1); - dlr_control_.resize(length + 1); - drl_control_.resize(length + 1); - dlr_control_.zero(length); - drl_control_.zero(length); +void RNNLearner::Clear() { + lr_bkw_.Clear(); + if (rnn_->bidir_) rl_bkw_.Clear(); +} - return BiChannel(&dlr_hidden_, &drl_hidden_); +void RNNLearner::CollectGradients(std::vector *gradients) { + gradients->push_back(&lr_bkw_); + if (rnn_->bidir_) gradients->push_back(&rl_bkw_); } -Channel *BiLSTMLearner::Backpropagate() { - // Clear input gradient. - int length = lr_.size(); - dinput_.reset(length); +void RNNStack::AddLayer(const RNN::Spec &spec, bool bidir) { + string name = name_ + "/rnn" + std::to_string(layers_.size()); + layers_.emplace_back(name, spec, bidir); +} - // Propagate gradients for left-to-right LSTM. - for (int i = length - 1; i >= 0; --i) { - // Set reference to primal cell. - lr_gradient_.Set(bilstm_.lr_.primal, lr_[i]); +void RNNStack::AddLayers(int layers, const RNN::Spec &spec, bool bidir) { + for (int l = 0; l < layers; ++l) { + AddLayer(spec, bidir); + } +} - // Gradient inputs. 
- lr_gradient_.Set(bilstm_.lr_.dh_out, &dlr_hidden_, i); - lr_gradient_.Set(bilstm_.lr_.dc_out, &dlr_control_, i); +RNN::Variables RNNStack::Build(Flow *flow, + Flow::Variable *input, + Flow::Variable *dinput) { + RNN::Variables vars; + vars.input = vars.output = input; + vars.dinput = vars.doutput = dinput; + for (RNNLayer &l : layers_) { + RNN::Variables v = l.Build(flow, vars.output, vars.doutput); + vars.output = v.output; + vars.doutput = v.doutput; + } + return vars; +} - // Gradient outputs. - lr_gradient_.Set(bilstm_.lr_.dh_in, &dlr_hidden_, i > 0 ? i - 1 : length); - lr_gradient_.Set(bilstm_.lr_.dc_in, &dlr_control_, i > 0 ? i - 1 : length); - lr_gradient_.Set(bilstm_.lr_.dinput, &dinput_, i); +void RNNStack::Initialize(const Network &net) { + for (RNNLayer &l : layers_) { + l.Initialize(net); + } +} - // Compute backward. - lr_gradient_.Compute(); +RNNStackInstance::RNNStackInstance(const RNNStack &stack) { + layers_.reserve(stack.layers().size()); + for (const RNNLayer &l : stack.layers()) { + layers_.emplace_back(&l); } +} - // Propagate gradients for right-to-left LSTM. - for (int i = 0; i < length; ++i) { - // Set reference to primal cell. - rl_gradient_.Set(bilstm_.rl_.primal, rl_[i]); +Channel *RNNStackInstance::Compute(Channel *input) { + Channel *channel = input; + for (RNNInstance &l : layers_) { + channel = l.Compute(channel); + } + return channel; +} - // Gradient inputs. - rl_gradient_.Set(bilstm_.rl_.dh_out, &drl_hidden_, i); - rl_gradient_.Set(bilstm_.rl_.dc_out, &drl_control_, i); +RNNStackLearner::RNNStackLearner(const RNNStack &stack) { + layers_.reserve(stack.layers().size()); + for (const RNNLayer &l : stack.layers()) { + layers_.emplace_back(&l); + } +} - // Gradient outputs. - rl_gradient_.Set(bilstm_.rl_.dh_in, &drl_hidden_, i + 1); - rl_gradient_.Set(bilstm_.rl_.dc_in, &drl_control_, i + 1); - rl_gradient_.Set(bilstm_.rl_.dinput, &dinput_, i); +Channel *RNNStackLearner::Compute(Channel *input) { + Channel *channel = input; + for (RNNLearner &l : layers_) { + channel = l.Compute(channel); + } + return channel; +} - // Compute backward. - rl_gradient_.Compute(); +Channel *RNNStackLearner::Backpropagate(Channel *doutput) { + Channel *channel = doutput; + for (int i = layers_.size() - 1; i >= 0; --i) { + channel = layers_[i].Backpropagate(channel); } + return channel; +} - // Return input gradient. - return &dinput_; +void RNNStackLearner::Clear() { + for (RNNLearner &l : layers_) { + l.Clear(); + } +} + +void RNNStackLearner::CollectGradients(std::vector *gradients) { + for (RNNLearner &l : layers_) { + l.CollectGradients(gradients); + } } } // namespace myelin diff --git a/sling/myelin/rnn.h b/sling/myelin/rnn.h index fa880670..16650bd3 100644 --- a/sling/myelin/rnn.h +++ b/sling/myelin/rnn.h @@ -15,6 +15,8 @@ #ifndef SLING_MYELIN_RNN_H_ #define SLING_MYELIN_RNN_H_ +#include +#include #include #include "sling/myelin/compute.h" @@ -23,133 +25,309 @@ namespace sling { namespace myelin { -// Channel pair with left-to-right and right-to-left channels. -struct BiChannel { - BiChannel(Channel *lr, Channel *rl) : lr(lr), rl(rl) {} - Channel *lr; // left-to-right channel - Channel *rl; // right-to-left channel +class RNNInstance; +class RNNLearner; + +// Recurrent Neural Network (RNN) cell. +struct RNN { + // RNN types. + enum Type { + // Standard LSTM [Hochreiter & Schmidhuber 1997]. + LSTM = 0, + + // LSTM with peephole connections [Gers & Schmidhuber 2000] and coupled + // forget and input gates [Greff et al. 2015]. 
+ DRAGNN_LSTM = 1, + + // Standard LSTM with one matrix multiplication [Dozat & Manning 2017]. + DOZAT_LSTM = 2, + + // Standard LSTM with two matrix multiplications [Paszke et al. 2019]. + PYTORCH_LSTM = 3, + + // Gated Recurrent Unit (GRU) [Cho et al. 2014]. + GRU = 4, + }; + + // RNN specification. + struct Spec { + Type type = LSTM; // RNN type + int dim = 128; // RNN dimension + bool highways = false; // use highway connections between layers + float dropout = 0.0; // dropout rate during training (0=no dropout) + }; + + // Flow input/output variables. + struct Variables { + Flow::Variable *input = nullptr; // input to forward path + Flow::Variable *output = nullptr; // output from forward path + Flow::Variable *doutput = nullptr; // gradient input to backward path + Flow::Variable *dinput = nullptr; // gradient output from backward path + }; + + // Initialize RNN. + RNN(const string &name, const Spec &spec) : name(name), spec(spec) {} + + // Build flow for RNN. If dinput is not null, the corresponding gradient + // function is also built. + Variables Build(Flow *flow, + Flow::Variable *input, + Flow::Variable *dinput = nullptr); + + // Initialize RNN. + void Initialize(const Network &net); + + // Control channel is optional for RNN. + bool has_control() const { return c_in != nullptr; } + + // Dropout is only needed during training. + bool has_mask() const { return mask != nullptr; } + + string name; // RNN cell name + Spec spec; // RNN specification + + Cell *cell = nullptr; // RNN cell + Tensor *input = nullptr; // RNN feature input + Tensor *h_in = nullptr; // link to RNN hidden input + Tensor *h_out = nullptr; // link to RNN hidden output + Tensor *c_in = nullptr; // link to RNN control input + Tensor *c_out = nullptr; // link to RNN control output + Tensor *zero = nullptr; // zero element for channels + Tensor *mask = nullptr; // dropout mask input + + Tensor *nodropout = nullptr; // dropout mask with no dropout + + Cell *gcell = nullptr; // RNN gradient cell + Tensor *dinput = nullptr; // input gradient + Tensor *primal = nullptr; // link to primal RNN cell + Tensor *dh_in = nullptr; // gradient for RNN hidden input + Tensor *dh_out = nullptr; // gradient for RNN hidden output + Tensor *dc_in = nullptr; // gradient for RNN control input + Tensor *dc_out = nullptr; // gradient for RNN control output + Tensor *sink = nullptr; // scratch element for channels }; -// Bi-directional long short-term memory (LSTM) module. -class BiLSTM { - public: - // Flow output variables. - struct Outputs { - Flow::Variable *lr; // output from left-to-right LSTM (hidden) - Flow::Variable *rl; // output from right-to-left LSTM (hidden) - Flow::Variable *dlr; // gradient output from right-to-left LSTM (dinput) - Flow::Variable *drl; // gradient output from right-to-left LSTM (dinput) +// Channel merger cell for merging the outputs from two RNNs. +struct RNNMerger { + // Flow input/output variables. + struct Variables { + Flow::Variable *left; // left input to forward path + Flow::Variable *right; // right input to forward path + Flow::Variable *merged; // merged output from forward path + + Flow::Variable *dmerged; // merged gradient from backward path + Flow::Variable *dleft; // left gradient output from backward path + Flow::Variable *dright; // right gradient output from backward path }; - // Initialize bi-directional LSTM. - BiLSTM(const string &name = "lstm") : name_(name) {} + // Initialize RNN merger. + RNNMerger(const string &name) : name(name) {} + + // Build flow for channel merger. 
If dleft and dright are not null, the + // corresponding gradient function is also built. + Variables Build(Flow *flow, + Flow::Variable *left, Flow::Variable *right, + Flow::Variable *dleft, Flow::Variable *dright); + + // Initialize channel merger. + void Initialize(const Network &net); + + string name; // cell name + + Cell *cell = nullptr; // merger cell + Tensor *left = nullptr; // left channel input + Tensor *right = nullptr; // right channel input + Tensor *merged = nullptr; // merged output channel + + Cell *gcell = nullptr; // merger gradient cell + Tensor *dmerged = nullptr; // gradient for merged channel + Tensor *dleft = nullptr; // gradient for left channel + Tensor *dright = nullptr; // gradient for right channel +}; + +// An RNN layer can be either unidirectional (left-to-right) or bidirectional +// (both left-to-right and right-to-left). The outputs from the the two RNNs +// in a bidirectional RNN are merged using an RNN channel merger. +class RNNLayer { + public: + // Set up RNN layer. + RNNLayer(const string &name, const RNN::Spec &spec, bool bidir); - // Build flows for LSTMs. - Outputs Build(Flow *flow, int dim, - Flow::Variable *input, - Flow::Variable *dinput = nullptr); + // Build flow for RNN. If dinput is not null, the corresponding gradient + // function is also built. + RNN::Variables Build(Flow *flow, + Flow::Variable *input, + Flow::Variable *dinput = nullptr); - // Initialize LSTMs. + // Initialize RNN. void Initialize(const Network &net); private: - // Network for LSTM cell. - struct LSTM { - // Initialize LSTM cell from network. - void Initialize(const Network &net, const string &name); - - Cell *cell = nullptr; // LSTM cell - Tensor *input = nullptr; // LSTM feature input - Tensor *h_in = nullptr; // link to LSTM hidden input - Tensor *h_out = nullptr; // link to LSTM hidden output - Tensor *c_in = nullptr; // link to LSTM control input - Tensor *c_out = nullptr; // link to LSTM control output - - Cell *gcell = nullptr; // LSTM gradient cell - Tensor *dinput = nullptr; // input gradient - Tensor *primal = nullptr; // link to primal LSTM cell - Tensor *dh_in = nullptr; // gradient for LSTM hidden input - Tensor *dh_out = nullptr; // gradient for LSTM hidden output - Tensor *dc_in = nullptr; // gradient for LSTM control input - Tensor *dc_out = nullptr; // gradient for LSTM control output - }; + string name_; // cell name prefix + bool bidir_; // bidirectional RNN + float dropout_; // dropout ratio during learning. - string name_; // LSTM cell name prefix - LSTM lr_; // left-to-right LSTM - LSTM rl_; // right-to-left LSTM + RNN lr_; // left-to-right RNN + RNN rl_; // right-to-left RNN (if bidirectional) + RNNMerger merger_; // channel merger for bidirectional RNN - friend class BiLSTMInstance; - friend class BiLSTMLearner; + friend class RNNInstance; + friend class RNNLearner; }; -// Bi-directional LSTM instance. -class BiLSTMInstance { +// Instance of RNN layer for inference. +class RNNInstance { public: - // Initialize bi-directional LSTM instance. - BiLSTMInstance(const BiLSTM &bilstm); + RNNInstance(const RNNLayer *rnn); - // Compute left-to-right and right-to-left LSTM sequences for input. - BiChannel Compute(Channel *input); + // Compute RNN over input sequence and return output sequence. + Channel *Compute(Channel *input); private: - const BiLSTM &bilstm_; // bi-directional LSTM + // Descriptor for RNN layer. + const RNNLayer *rnn_; + + // Left-to-right RNN. 
+ Instance lr_; + Channel lr_hidden_; + Channel lr_control_; + + // Right-to-left RNN for bidirectional RNN. + Instance rl_; + Channel rl_hidden_; + Channel rl_control_; + + // RNN channel merger for bidirectional RNN. + Instance merger_; + Channel merged_; +}; + +// Instance of RNN layer for learning. +class RNNLearner { + public: + RNNLearner(const RNNLayer *rnn); - Instance lr_; // left-to-right LSTM instance - Instance rl_; // right-to-left LSTM instance + // Compute RNN over input sequence and return output sequence. Dropout is + // only applied in learning mode. + Channel *Compute(Channel *input); - Channel lr_hidden_; // left-to-right LSTM hidden channel - Channel lr_control_; // left-to-right LSTM control channel - Channel rl_hidden_; // right-to-left LSTM hidden channel - Channel rl_control_; // right-to-left LSTM control channel + // Backpropagate gradients returning the output of backpropagation, i.e. the + // gradient of the input sequence. + Channel *Backpropagate(Channel *doutput); + + // Clear accumulated gradients. + void Clear(); + + // Collect instances with gradient updates. + void CollectGradients(std::vector *gradients); + + private: + // Generate uniform random number between 0 and 1. + float Random() { return prob_(prng_); } + + // Descriptor for RNN layer. + const RNNLayer *rnn_; + + // Left-to-right RNN. + InstanceArray lr_fwd_; + Channel lr_hidden_; + Channel lr_control_; + + Instance lr_bkw_; + Channel lr_dhidden_; + Channel lr_dcontrol_; + + // Right-to-left RNN for bidirectional RNN. + InstanceArray rl_fwd_; + Channel rl_hidden_; + Channel rl_control_; + + Instance rl_bkw_; + Channel rl_dhidden_; + Channel rl_dcontrol_; + + // Channel for gradient output. + Channel dinput_; + + // RNN channel merger for bidirectional RNN. + Instance merger_; + Instance splitter_; + Channel merged_; + Channel dleft_; + Channel dright_; + + // Channel for dropout mask. + Channel mask_; + + // Random generator for dropout. + std::mt19937_64 prng_; + std::uniform_real_distribution prob_{0.0, 1.0}; }; -// Bi-directional LSTM learner. -class BiLSTMLearner { +// Multi-layer RNN. +class RNNStack { public: - // Initialize bi-directional LSTM learner. - BiLSTMLearner(const BiLSTM &bilstm); - ~BiLSTMLearner(); + RNNStack(const string &name) : name_(name) {} - // Compute left-to-right and right-to-left LSTM sequences for input. - BiChannel Compute(Channel *input); + // Add RNN layer. + void AddLayer(const RNN::Spec &spec, bool bidir); - // Prepare gradient channels. - BiChannel PrepareGradientChannels(int length); + // Add multiple RNN layers of the same type. + void AddLayers(int layers, const RNN::Spec &spec, bool bidir); - // Backpropagate hidden gradients to input gradient. - Channel *Backpropagate(); + // Build flow for RNNs. + RNN::Variables Build(Flow *flow, + Flow::Variable *input, + Flow::Variable *dinput = nullptr); - // Collect gradients. - void CollectGradients(std::vector *gradients) { - gradients->push_back(&lr_gradient_); - gradients->push_back(&rl_gradient_); - } + // Initialize RNN stack. + void Initialize(const Network &net); - // Clear gradients. - void Clear() { - lr_gradient_.Clear(); - rl_gradient_.Clear(); - } + // Layers in RNN stack. + const std::vector &layers() const { return layers_; } private: - const BiLSTM &bilstm_; // bi-directional LSTM + // Name prefix for RNN cells. 
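The learner above keeps a dropout mask channel and a PRNG; the code that fills the mask lives elsewhere in rnn.cc and is not part of this hunk. A plausible inverted-dropout sketch, offered only as an assumption about how such a mask could be filled per step, would be:

```cpp
// Keep each unit with probability 1 - dropout and rescale survivors so the
// expected activation matches inference (inverted dropout). `step` and `dims`
// are placeholders for the current token index and hidden dimension.
float keep = 1.0f - rnn_->dropout_;
float *mask = reinterpret_cast<float *>(mask_.at(step));
for (int d = 0; d < dims; ++d) {
  mask[d] = (Random() < keep) ? 1.0f / keep : 0.0f;
}
```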
+ string name_; - std::vector lr_; // left-to-right LSTM instances - std::vector rl_; // right-to-left LSTM instances - Instance lr_gradient_; // left-to-right LSTM gradients - Instance rl_gradient_; // right-to-left LSTM gradients + // RNN layers. + std::vector layers_; +}; - Channel lr_hidden_; // left-to-right LSTM hidden channel - Channel lr_control_; // left-to-right LSTM control channel - Channel rl_hidden_; // right-to-left LSTM hidden channel - Channel rl_control_; // right-to-left LSTM control channel +// Multi-layer RNN instance for prediction. +class RNNStackInstance { + public: + RNNStackInstance(const RNNStack &stack); + + // Compute RNN over input sequence and return output sequence. + Channel *Compute(Channel *input); + + private: + // RNN prediction instances for all layers. + std::vector layers_; +}; - Channel dlr_hidden_; // left-to-right LSTM hidden gradient channel - Channel dlr_control_; // left-to-right LSTM control gradient channel - Channel drl_hidden_; // right-to-left LSTM hidden gradient channel - Channel drl_control_; // right-to-left LSTM control gradient channel +// Multi-layer RNN layer for learning. +class RNNStackLearner { + public: + RNNStackLearner(const RNNStack &stack); - Channel dinput_; // input gradient channel + // Compute RNN over input sequence and return output sequence. + Channel *Compute(Channel *input); + + // Backpropagate gradients returning the output of backpropagation, i.e. the + // gradient of the input sequence. + Channel *Backpropagate(Channel *doutput); + + // Clear accumulated gradients. + void Clear(); + + // Collect instances with gradient updates. + void CollectGradients(std::vector *gradients); + + private: + // RNN learner instances for all layers. + std::vector layers_; }; } // namespace myelin diff --git a/sling/myelin/simd-assembler.cc b/sling/myelin/simd-assembler.cc index f8d99856..2edeaed2 100644 --- a/sling/myelin/simd-assembler.cc +++ b/sling/myelin/simd-assembler.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include - #include "sling/myelin/simd-assembler.h" namespace sling { @@ -21,14 +19,6 @@ namespace myelin { using namespace jit; -template StaticData *MinVal(MacroAssembler *masm, int repeat) { - return masm->GetConstant(std::numeric_limits::min(), repeat); -} - -template StaticData *MaxVal(MacroAssembler *masm, int repeat) { - return masm->GetConstant(std::numeric_limits::max(), repeat); -} - StaticData *SIMDGenerator::NeutralElement(Reduction op, Type type, int repeat) { switch (op) { case REDUCE_ADD: @@ -45,22 +35,22 @@ StaticData *SIMDGenerator::NeutralElement(Reduction op, Type type, int repeat) { } case REDUCE_MIN: switch (type) { - case DT_FLOAT: return MaxVal(masm_, repeat); - case DT_DOUBLE: return MaxVal(masm_, repeat); - case DT_INT8: return MaxVal(masm_, repeat); - case DT_INT16: return MaxVal(masm_, repeat); - case DT_INT32: return MaxVal(masm_, repeat); - case DT_INT64: return MaxVal(masm_, repeat); + case DT_FLOAT: return masm_->MaxVal(repeat); + case DT_DOUBLE: return masm_->MaxVal(repeat); + case DT_INT8: return masm_->MaxVal(repeat); + case DT_INT16: return masm_->MaxVal(repeat); + case DT_INT32: return masm_->MaxVal(repeat); + case DT_INT64: return masm_->MaxVal(repeat); default: return nullptr; } case REDUCE_MAX: switch (type) { - case DT_FLOAT: return MinVal(masm_, repeat); - case DT_DOUBLE: return MinVal(masm_, repeat); - case DT_INT8: return MinVal(masm_, repeat); - case DT_INT16: return MinVal(masm_, repeat); - case DT_INT32: return MinVal(masm_, repeat); - case DT_INT64: return MinVal(masm_, repeat); + case DT_FLOAT: return masm_->MinVal(repeat); + case DT_DOUBLE: return masm_->MinVal(repeat); + case DT_INT8: return masm_->MinVal(repeat); + case DT_INT16: return masm_->MinVal(repeat); + case DT_INT32: return masm_->MinVal(repeat); + case DT_INT64: return masm_->MinVal(repeat); default: return nullptr; } case REDUCE_AND: @@ -2815,15 +2805,15 @@ int SIMDAssembler::RegisterUsage(Type type) { switch (type) { case DT_INT8: case DT_INT16: - return 1; + return 2; case DT_INT32: if (CPU::Enabled(AVX512F)) return 0; if (CPU::Enabled(AVX2)) return 0; if (CPU::Enabled(SSE4_1) && CPU::Enabled(SSSE3)) return 0; - return 1; + return 2; case DT_INT64: if (CPU::Enabled(AVX512F)) return 0; - return 1; + return 2; default: return 0; } diff --git a/sling/myelin/tests/opcheck.py b/sling/myelin/tests/opcheck.py index 9b5e3c4e..7e620f3e 100644 --- a/sling/myelin/tests/opcheck.py +++ b/sling/myelin/tests/opcheck.py @@ -849,6 +849,8 @@ def acc_matmul_test(m, k, n): div_test(i) minimum_test(i) maximum_test(i) + argmax_test(i) + argmin_test(i) neg_test(i) abs_test(i) square_test(i) @@ -921,11 +923,6 @@ def acc_matmul_test(m, k, n): cond_test(i) select_test(i) - if dt != myelin.DT_DOUBLE: - # No support yet for argmax and argmin for doubles. - argmax_test(i) - argmin_test(i) - for i in sizes: for j in sizes: matmul_transpose_test(i, j) diff --git a/sling/nlp/document/lexical-encoder.cc b/sling/nlp/document/lexical-encoder.cc index e7a54451..c851ab8f 100644 --- a/sling/nlp/document/lexical-encoder.cc +++ b/sling/nlp/document/lexical-encoder.cc @@ -51,7 +51,7 @@ void LexicalFeatures::LoadLexicon(Flow *flow) { lexicon_.PrecomputeShapes(); } -void LexicalFeatures::SaveLexicon(myelin::Flow *flow) const { +void LexicalFeatures::SaveLexicon(Flow *flow) const { // Save word vocabulary. 
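The NeutralElement change above swaps the free MinVal/MaxVal helpers for MacroAssembler methods; the underlying identity is the usual one for reductions, illustrated here in plain C++ rather than Myelin generator code:

```cpp
#include <algorithm>
#include <limits>

// The neutral element of a min-reduction is the largest representable value,
// so any real element replaces it on the first comparison; max-reductions use
// the smallest value for the same reason.
float ReduceMin(const float *v, int n) {
  float acc = std::numeric_limits<float>::max();
  for (int i = 0; i < n; ++i) acc = std::min(acc, v[i]);
  return acc;
}
```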
Flow::Blob *vocabulary = flow->AddBlob("lexicon", "dict"); vocabulary->SetAttr("delimiter", 0); @@ -402,26 +402,6 @@ void LexicalFeatureExtractor::Extract(const Document &document, for (int i = 0; i < length; ++i) { float *f = reinterpret_cast(fv->at(i)); Compute(features, i, f); - if (trace_) OutputTrace(i); - } -} - -void LexicalFeatureExtractor::OutputTrace(int token) { - OutputTrace(token, lex_.word_feature_); - OutputTrace(token, lex_.prefix_feature_); - OutputTrace(token, lex_.suffix_feature_); - OutputTrace(token, lex_.hyphen_feature_); - OutputTrace(token, lex_.caps_feature_); - OutputTrace(token, lex_.punct_feature_); - OutputTrace(token, lex_.quote_feature_); - OutputTrace(token, lex_.digit_feature_); -} - -void LexicalFeatureExtractor::OutputTrace(int token, myelin::Tensor *feature) { - if (feature == nullptr) return; - const int *values = data_.Get(feature); - for (int i = 0; i < feature->elements(); ++i) { - if (values[i] != -1) trace_(token, feature->name(), values[i]); } } @@ -456,43 +436,43 @@ void LexicalFeatureLearner::Backpropagate(Channel *dfv) { } } -BiLSTM::Outputs LexicalEncoder::Build(Flow *flow, - const LexicalFeatures::Spec &spec, - Vocabulary::Iterator *words, - int dim, bool learn) { +RNN::Variables LexicalEncoder::Build(Flow *flow, + const LexicalFeatures::Spec &spec, + Vocabulary::Iterator *words, + bool learn) { if (words != nullptr) { lex_.InitializeLexicon(words, spec.lexicon); } auto lexvars = lex_.Build(flow, spec, learn); - return bilstm_.Build(flow, dim, lexvars.fv, lexvars.dfv); + return rnn_.Build(flow, lexvars.fv, lexvars.dfv); } -void LexicalEncoder::Initialize(const myelin::Network &net) { +void LexicalEncoder::Initialize(const Network &net) { lex_.Initialize(net); - bilstm_.Initialize(net); + rnn_.Initialize(net); } -myelin::BiChannel LexicalEncoderInstance::Compute(const Document &document, - int begin, int end) { +Channel *LexicalEncoderInstance::Compute(const Document &document, + int begin, int end) { // Extract feature and map through feature embeddings. features_.Extract(document, begin, end, &fv_); - // Compute hidden states of LSTMs. - return bilstm_.Compute(&fv_); + // Compute hidden states for RNN. + return rnn_.Compute(&fv_); } -myelin::BiChannel LexicalEncoderLearner::Compute(const Document &document, - int begin, int end) { +Channel *LexicalEncoderLearner::Compute(const Document &document, + int begin, int end) { // Extract feature and map through feature embeddings. - myelin::Channel *fv = features_.Extract(document, begin, end); + Channel *fv = features_.Extract(document, begin, end); - // Compute hidden states of LSTMs. - return bilstm_.Compute(fv); + // Compute hidden states for RNN. + return rnn_.Compute(fv); } -void LexicalEncoderLearner::Backpropagate() { - // Backpropagate hidden state gradients through LSTMs. - myelin::Channel *dfv = bilstm_.Backpropagate(); +void LexicalEncoderLearner::Backpropagate(Channel *doutput) { + // Backpropagate hidden state gradients through RNN. + Channel *dfv = rnn_.Backpropagate(doutput); // Backpropagate feature vector gradients to feature embeddings. features_.Backpropagate(dfv); diff --git a/sling/nlp/document/lexical-encoder.h b/sling/nlp/document/lexical-encoder.h index f958ed9c..c7a79033 100644 --- a/sling/nlp/document/lexical-encoder.h +++ b/sling/nlp/document/lexical-encoder.h @@ -121,12 +121,8 @@ class LexicalFeatures { friend class LexicalFeatureLearner; }; -// Callback for tracing extracted features. 
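A minimal sketch of the new single-channel interface at prediction time; `encoder`, `document`, `begin` and `end` are assumed to come from the surrounding parser code:

```cpp
nlp::LexicalEncoderInstance instance(encoder);
myelin::Channel *hidden = instance.Compute(document, begin, end);
// One hidden-state vector per token from the top RNN layer; previously this
// returned a BiChannel with separate left-to-right and right-to-left parts.
```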
-typedef std::function< - void(int token, const string &feature, int value)> LexicalFeatureTrace; - // Lexical feature extractor for extracting features from document tokens and -// mapping these though feature embeddings. +// mapping these through feature embeddings. class LexicalFeatureExtractor { public: LexicalFeatureExtractor(const LexicalFeatures &lex) @@ -143,17 +139,9 @@ class LexicalFeatureExtractor { // Data instance for feature extraction. myelin::Instance *data() { return &data_; } - // Set trace callback. - void set_trace(LexicalFeatureTrace trace) { trace_ = trace; } - private: - // Call trace function for extracted feature values for current token. - void OutputTrace(int token); - void OutputTrace(int token, myelin::Tensor *feature); - const LexicalFeatures &lex_; myelin::Instance data_; - LexicalFeatureTrace trace_; }; // Lexical feature learner for training feature embeddings. @@ -185,20 +173,22 @@ class LexicalFeatureLearner { myelin::Instance gradient_; }; -// A lexical encoder is a lexical feature extractor with a bi-directional LSTM -// on top. +// A lexical encoder is a lexical feature extractor with an RNN on top. class LexicalEncoder { public: LexicalEncoder(const string &lexname = "features", - const string &lstmname = "lstm") - : lex_(lexname), bilstm_(lstmname) {} + const string &rnnname = "encoder") + : lex_(lexname), rnn_(rnnname) {} + + // Add RNN layers to encoder. + void AddLayers(int layers, const myelin::RNN::Spec spec, bool bidir) { + rnn_.AddLayers(layers, spec, bidir); + } - // Build flow for lexical encoder. Returns the output variables from the - // LSTMs. - myelin::BiLSTM::Outputs Build(myelin::Flow *flow, - const LexicalFeatures::Spec &spec, - Vocabulary::Iterator *words, - int dim, bool learn); + // Build flow for lexical encoder. Returns the output variables from the RNN. + myelin::RNN::Variables Build(myelin::Flow *flow, + const LexicalFeatures::Spec &spec, + Vocabulary::Iterator *words, bool learn); // Initialize feature extractor from existing model. void Initialize(const myelin::Network &net); @@ -216,8 +206,8 @@ class LexicalEncoder { // Lexical feature extractor with embeddings. LexicalFeatures lex_; - // Bi-directional LSTM. - myelin::BiLSTM bilstm_; + // RNN encoder. + myelin::RNNStack rnn_; friend class LexicalEncoderInstance; friend class LexicalEncoderLearner; @@ -229,22 +219,18 @@ class LexicalEncoderInstance { LexicalEncoderInstance(const LexicalEncoder &encoder) : encoder_(encoder), features_(encoder_.lex_), - bilstm_(encoder_.bilstm_), + rnn_(encoder_.rnn_), fv_(encoder.lex().feature_vector()) {} // Extract lexical features from a range of tokens in a document, map the - // features through the feature embeddings, and run the bi-directional LSTM - // encoder. Returns the left-to-right and right-to-left channels for the - // hidden state of the LSTMs. - myelin::BiChannel Compute(const Document &document, int begin, int end); - - // Sets feature extraction tracing callback. - void set_trace(LexicalFeatureTrace trace) { features_.set_trace(trace); } + // features through the feature embeddings, and run the RNN encoder. Returns + // the channel for the hidden state of the RNN. 
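To show how the AddLayers hook above is meant to be used, here is a sketch of configuring the encoder before building the flow; the lexicon spec `lexspec`, the vocabulary iterator `words` and the `flow` object are assumptions about the caller:

```cpp
LexicalEncoder encoder("features", "encoder");

myelin::RNN::Spec rnn;
rnn.type = myelin::RNN::LSTM;
rnn.dim = 128;
encoder.AddLayers(/*layers=*/1, rnn, /*bidir=*/true);

// Builds the feature embeddings and the RNN stack; with learn=true the
// gradient cells are emitted as well.
myelin::RNN::Variables vars =
    encoder.Build(&flow, lexspec, words, /*learn=*/true);
```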
+ myelin::Channel *Compute(const Document &document, int begin, int end); private: const LexicalEncoder &encoder_; LexicalFeatureExtractor features_; - myelin::BiLSTMInstance bilstm_; + myelin::RNNStackInstance rnn_; myelin::Channel fv_; }; @@ -254,35 +240,30 @@ class LexicalEncoderLearner { LexicalEncoderLearner(const LexicalEncoder &encoder) : encoder_(encoder), features_(encoder.lex_), - bilstm_(encoder_.bilstm_) {} + rnn_(encoder_.rnn_) {} - // Compute hidden states for the LSTMs from input document. - myelin::BiChannel Compute(const Document &document, int begin, int end); - - // Prepare gradient channels. - myelin::BiChannel PrepareGradientChannels(int length) { - return bilstm_.PrepareGradientChannels(length); - } + // Compute hidden state for the RNN from input document. + myelin::Channel *Compute(const Document &document, int begin, int end); // Backpropagate hidden state gradients. - void Backpropagate(); + void Backpropagate(myelin::Channel *doutput); // Collect gradients. void CollectGradients(std::vector *gradients) { features_.CollectGradients(gradients); - bilstm_.CollectGradients(gradients); + rnn_.CollectGradients(gradients); } // Clear gradients. void Clear() { features_.Clear(); - bilstm_.Clear(); + rnn_.Clear(); } private: const LexicalEncoder &encoder_; LexicalFeatureLearner features_; - myelin::BiLSTMLearner bilstm_; + myelin::RNNStackLearner rnn_; }; } // namespace nlp diff --git a/sling/nlp/embedding/BUILD b/sling/nlp/embedding/BUILD index abcc0567..15e666ce 100644 --- a/sling/nlp/embedding/BUILD +++ b/sling/nlp/embedding/BUILD @@ -10,7 +10,6 @@ cc_library( "//sling/myelin:flow", "//sling/myelin:gradient", "//sling/myelin:learning", - "//sling/util:random", ], ) diff --git a/sling/nlp/embedding/embedding-model.cc b/sling/nlp/embedding/embedding-model.cc index af5997fe..07f9351e 100644 --- a/sling/nlp/embedding/embedding-model.cc +++ b/sling/nlp/embedding/embedding-model.cc @@ -16,7 +16,6 @@ #include "sling/myelin/builder.h" #include "sling/myelin/gradient.h" -#include "sling/util/random.h" namespace sling { namespace nlp { @@ -33,7 +32,7 @@ void MikolovFlow::Build() { void MikolovFlow::BuildModel() { W0 = AddWeights("W0", DT_FLOAT, {inputs, dims}); W1 = AddWeights("W1", DT_FLOAT, {outputs, dims}); - W0->set_random(); + W0->init = Variable::INIT_UNIFORM; } void MikolovFlow::BuildLayer0() { @@ -116,7 +115,8 @@ void DualEncoderFlow::BuildEncoder(Encoder *encoder) { encoder->forward = AddFunction(encoder->name); FlowBuilder tf(this, encoder->forward); encoder->embeddings = - tf.Random(tf.Parameter("embeddings", DT_FLOAT, {encoder->dims, dims})); + tf.Parameter("embeddings", DT_FLOAT, {encoder->dims, dims}); + tf.RandomNormal(encoder->embeddings); encoder->features = tf.Placeholder("features", DT_INT32, {1, encoder->max_features}); auto *sum = tf.GatherSum(encoder->embeddings, encoder->features); diff --git a/sling/nlp/embedding/fact-embeddings.cc b/sling/nlp/embedding/fact-embeddings.cc index 2957e903..9bcea051 100644 --- a/sling/nlp/embedding/fact-embeddings.cc +++ b/sling/nlp/embedding/fact-embeddings.cc @@ -93,7 +93,7 @@ class FactEmbeddingsTrainer : public LearnerTask { loss_.Initialize(model); // Initialize weights. - model.InitLearnableWeights(task->Get("seed", 0), 0.0, 0.01); + model.InitModelParameters(task->Get("seed", 0)); // Read training instances from input. 
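The embedding changes above replace `f.Random(...)` with an explicit parameter plus `RandomNormal`, and `InitLearnableWeights` with `InitModelParameters`. A sketch of the resulting idiom, with the shape, the seed and the `model` network object as placeholders:

```cpp
// In the flow builder: declare the parameter and request random-normal init.
auto *embeddings = f.Parameter("embeddings", DT_FLOAT, {rows, dims});
f.RandomNormal(embeddings);

// After the network has been compiled: seed and materialize the initializers.
model.InitModelParameters(seed);
```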
LOG(INFO) << "Reading training data"; diff --git a/sling/nlp/embedding/fact-plausibility.cc b/sling/nlp/embedding/fact-plausibility.cc index 3b3c87e7..c10eb620 100644 --- a/sling/nlp/embedding/fact-plausibility.cc +++ b/sling/nlp/embedding/fact-plausibility.cc @@ -37,7 +37,7 @@ struct FactPlausibilityFlow : public Flow { scorer = AddFunction("scorer"); FlowBuilder f(this, scorer); auto *embeddings = - f.Random(f.Parameter("embeddings", DT_FLOAT, {facts, dims})); + f.RandomNormal(f.Parameter("embeddings", DT_FLOAT, {facts, dims})); premise = f.Placeholder("premise", DT_INT32, {1, max_features}); auto *pencoding = f.GatherSum(embeddings, premise); @@ -134,7 +134,7 @@ class FactPlausibilityTrainer : public LearnerTask { loss_.Initialize(model); // Initialize weights. - model.InitLearnableWeights(seed_, 0.0, 0.01); + model.InitModelParameters(seed_); // Read training instances from input. LOG(INFO) << "Reading training data"; @@ -476,7 +476,7 @@ class FactPlausibilityTrainer : public LearnerTask { Build(&flow, false); // Copy weights from trained model. - model.SaveLearnedWeights(&flow); + model.SaveParameters(&flow); // Add fact lexicon. string encoded = Encode(fact_lexicon_); diff --git a/sling/nlp/parser/BUILD b/sling/nlp/parser/BUILD index aaf5797b..e60556b8 100644 --- a/sling/nlp/parser/BUILD +++ b/sling/nlp/parser/BUILD @@ -23,7 +23,6 @@ cc_library( "//sling/frame:object", "//sling/frame:store", "//sling/nlp/document", - "//sling/string:strcat", ], ) @@ -33,13 +32,10 @@ cc_library( hdrs = ["action-table.h"], deps = [ ":parser-action", - ":parser-state", "//sling/base", - "//sling/file", "//sling/frame:object", "//sling/frame:serialization", "//sling/frame:store", - "//sling/string:text", ], ) @@ -62,39 +58,8 @@ cc_library( ":parser-action", ":parser-state", ":roles", - ":trace", "//sling/base", "//sling/myelin:compute", - "//sling/myelin:rnn", - ], -) - -cc_library( - name = "cascade", - srcs = ["cascade.cc"], - hdrs = ["cascade.h"], - deps = [ - ":action-table", - ":parser-state", - ":trace", - "//sling/base", - "//sling/frame:serialization", - "//sling/frame:store", - "//sling/myelin:compute", - "//sling/myelin:flow", - ], -) - -cc_library( - name = "trace", - srcs = ["trace.cc"], - hdrs = ["trace.h"], - deps = [ - ":parser-action", - "//sling/base", - "//sling/frame:object", - "//sling/frame:store", - "//sling/nlp/document:document", ], ) @@ -103,12 +68,9 @@ cc_library( srcs = ["parser.cc"], hdrs = ["parser.h"], deps = [ - ":action-table", - ":cascade", ":parser-features", ":parser-state", ":roles", - ":trace", "//sling/base", "//sling/frame:serialization", "//sling/frame:store", @@ -124,6 +86,7 @@ cc_library( srcs = ["parser-annotator.cc"], deps = [ ":parser", + ":multiclass-delegate", "//sling/nlp/document:annotator", ], alwayslink = 1, @@ -203,3 +166,13 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "multiclass-delegate", + srcs = ["multiclass-delegate.cc"], + deps = [ + ":action-table", + ":parser", + ], + alwayslink = 1, +) + diff --git a/sling/nlp/parser/action-table.cc b/sling/nlp/parser/action-table.cc index 31b1725c..ead3ab3e 100644 --- a/sling/nlp/parser/action-table.cc +++ b/sling/nlp/parser/action-table.cc @@ -16,9 +16,7 @@ #include "sling/base/logging.h" #include "sling/base/types.h" -#include "sling/file/file.h" #include "sling/frame/serialization.h" -#include "sling/string/text.h" namespace sling { namespace nlp { @@ -33,25 +31,18 @@ void ActionTable::Add(const ParserAction &action) { } } -void ActionTable::Init(Store *store) { - Frame top(store, "/table"); - 
CHECK(top.valid()); - - // Get all the integer fields. - max_actions_per_token_ = top.GetInt("/table/max_actions_per_token", 5); - frame_limit_ = top.GetInt("/table/frame_limit", 5); - - // Read the action index. - Array actions = top.Get("/table/actions").AsArray(); +void ActionTable::Read(const Frame &frame) { + Array actions = frame.Get("actions").AsArray(); CHECK(actions.valid()); - Handle n_type = store->Lookup("/table/action/type"); - Handle n_length = store->Lookup("/table/action/length"); - Handle n_source = store->Lookup("/table/action/source"); - Handle n_target = store->Lookup("/table/action/target"); - Handle n_role = store->Lookup("/table/action/role"); - Handle n_label = store->Lookup("/table/action/label"); - Handle n_delegate = store->Lookup("/table/action/delegate"); + Store *store = frame.store(); + Handle n_type = store->Lookup("type"); + Handle n_length = store->Lookup("length"); + Handle n_source = store->Lookup("source"); + Handle n_target = store->Lookup("target"); + Handle n_role = store->Lookup("role"); + Handle n_label = store->Lookup("label"); + Handle n_delegate = store->Lookup("delegate"); for (int i = 0; i < actions.length(); ++i) { ParserAction action; Frame item(store, actions.get(i)); @@ -79,33 +70,16 @@ void ActionTable::Init(Store *store) { } } -void ActionTable::Save(const Store *global, const string &file) const { - string s = Serialize(global); - CHECK(File::WriteContents(file, s)); -} - -string ActionTable::Serialize(const Store *global) const { - // Build frame with action table. - Store store(global); - Builder table(&store); - table.AddId("/table"); - Write(&table); - - StringEncoder encoder(&store); - encoder.Encode(table.Create()); - return encoder.buffer(); -} - void ActionTable::Write(Builder *frame) const { // Save the action table. 
Store *store = frame->store(); - Handle n_type = store->Lookup("/table/action/type"); - Handle n_length = store->Lookup("/table/action/length"); - Handle n_source = store->Lookup("/table/action/source"); - Handle n_target = store->Lookup("/table/action/target"); - Handle n_role = store->Lookup("/table/action/role"); - Handle n_label = store->Lookup("/table/action/label"); - Handle n_delegate = store->Lookup("/table/action/delegate"); + Handle n_type = store->Lookup("type"); + Handle n_length = store->Lookup("length"); + Handle n_source = store->Lookup("source"); + Handle n_target = store->Lookup("target"); + Handle n_role = store->Lookup("role"); + Handle n_label = store->Lookup("label"); + Handle n_delegate = store->Lookup("delegate"); Array actions(store, actions_.size()); int index = 0; @@ -119,16 +93,12 @@ void ActionTable::Write(Builder *frame) const { b.Add(n_length, static_cast(action.length)); } } - if (type == ParserAction::ASSIGN || - type == ParserAction::ELABORATE || - type == ParserAction::CONNECT) { + if (type == ParserAction::ASSIGN || type == ParserAction::CONNECT) { if (action.source != 0) { b.Add(n_source, static_cast(action.source)); } } - if (type == ParserAction::EMBED || - type == ParserAction::REFER || - type == ParserAction::CONNECT) { + if (type == ParserAction::REFER || type == ParserAction::CONNECT) { if (action.target != 0) { b.Add(n_target, static_cast(action.target)); } @@ -140,7 +110,7 @@ void ActionTable::Write(Builder *frame) const { if (!action.label.IsNil()) b.Add(n_label, action.label); actions.set(index++, b.Create().handle()); } - frame->Add("/table/actions", actions); + frame->Add("actions", actions); } } // namespace nlp diff --git a/sling/nlp/parser/action-table.h b/sling/nlp/parser/action-table.h index 90dece81..055ec736 100644 --- a/sling/nlp/parser/action-table.h +++ b/sling/nlp/parser/action-table.h @@ -22,7 +22,6 @@ #include "sling/frame/object.h" #include "sling/frame/store.h" #include "sling/nlp/parser/parser-action.h" -#include "sling/nlp/parser/parser-state.h" namespace sling { namespace nlp { @@ -48,40 +47,18 @@ class ActionTable { // Return list of actions. const std::vector &list() const { return actions_; } - // Saves the action table. - void Save(const Store *global, const string &file) const; + // Read action table from frame. + void Read(const Frame &frame); - // Returns the serialization of the table. - string Serialize(const Store *global) const; - - // Write action table in frame. + // Write action table to frame. void Write(Builder *frame) const; - // Initialize the action table from store. - void Init(Store *store); - - // Maximum number of actions per token. - int max_actions_per_token() const { return max_actions_per_token_; } - void set_max_actions_per_token(int m) { - if (max_actions_per_token_ < m) max_actions_per_token_ = m; - } - - // Frame limit for source and target in parser actions. - int frame_limit() const { return frame_limit_; } - void set_frame_limit(int limit) { frame_limit_ = limit; } - private: // List of actions. std::vector actions_; // Mapping from parser action to index. std::unordered_map mapping_; - - // Maximum index of source and target for actions. - int frame_limit_ = 5; - - // Maximum number of actions taken per token. - int max_actions_per_token_ = -1; }; } // namespace nlp diff --git a/sling/nlp/parser/cascade.cc b/sling/nlp/parser/cascade.cc deleted file mode 100644 index 09624084..00000000 --- a/sling/nlp/parser/cascade.cc +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2017 Google Inc. 
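The action table now serializes into whatever frame the caller provides, using unqualified slot names. A sketch of the round trip, assuming `store` and a populated `actions` table come from the caller:

```cpp
// Writing: the trainer adds the actions to the delegate spec it is building.
Builder spec(store);
actions.Write(&spec);
Frame frame = spec.Create();

// Reading: the runtime delegate restores the table from the same frame.
ActionTable restored;
restored.Read(frame);
int index = restored.Index(action);                  // action -> class id
const ParserAction &back = restored.Action(index);   // class id -> action
```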
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "sling/nlp/parser/cascade.h" - -#include "sling/frame/serialization.h" - -REGISTER_COMPONENT_REGISTRY("delegate runtime", sling::nlp::Delegate); - -namespace sling { -namespace nlp { - -// Delegate that assumes a softmax output. -class SoftmaxDelegate : public Delegate { - public: - void Initialize(const Cascade *cascade, const Frame &spec) override { - input_ = cell_->GetParameter(cell_->name() + "/input"); - output_ = cell_->GetParameter(cell_->name() + "/output"); - - // Read delegate's action table. - Store *store = spec.store(); - Array actions(store, spec.GetHandle("actions")); - CHECK(actions.valid()) << name(); - - Handle n_type = store->Lookup("/table/action/type"); - Handle n_delegate = store->Lookup("/table/action/delegate"); - Handle n_length = store->Lookup("/table/action/length"); - Handle n_source = store->Lookup("/table/action/source"); - Handle n_target = store->Lookup("/table/action/target"); - Handle n_label = store->Lookup("/table/action/label"); - Handle n_role = store->Lookup("/table/action/role"); - for (int i = 0; i < actions.length(); ++i) { - ParserAction action; - Frame frame(store, actions.get(i)); - action.type = static_cast(frame.GetInt(n_type)); - if (frame.Has(n_delegate)) action.delegate = frame.GetInt(n_delegate); - if (frame.Has(n_length)) action.length = frame.GetInt(n_length); - if (frame.Has(n_source)) action.source = frame.GetInt(n_source); - if (frame.Has(n_target)) action.target = frame.GetInt(n_target); - if (frame.Has(n_label)) action.label = frame.GetHandle(n_label); - if (frame.Has(n_role)) action.role = frame.GetHandle(n_role); - actions_.push_back(action); - } - } - - void Compute( - myelin::Instance *instance, ParserAction *action) const override { - int best_index = *instance->Get(output_); - - // NOTE: A more general and slightly more expensive approach would be - // to call another virtual method here: - // Overlay(actions_[best_index], action); - // Right now we overwrite the under-construction action with the output. - *action = actions_[best_index]; - } - - private: - // Location of the delegate output (argmax of the softmax layer). - myelin::Tensor *output_ = nullptr; - - // Action table for the delegate. - std::vector actions_; -}; - -REGISTER_DELEGATE_RUNTIME("SoftmaxDelegate", SoftmaxDelegate); - -Cascade::Cascade() { - shift_.type = ParserAction::SHIFT; - stop_.type = ParserAction::STOP; -} - -Cascade::~Cascade() { - for (auto *d : delegates_) delete d; -} - -void Cascade::Initialize(const myelin::Network &network, const Frame &spec) { - Store *store = spec.store(); - Array delegates(store, spec.GetHandle("delegates")); - CHECK(delegates.valid()); - - // Create delegates from the spec. - // - // For each delegate, the spec contains (among possibly other things): - // - Name of the the Myelin cell that implements it. - // - The name of the runtime (i.e. subclass of 'Delegate') used to run it. - // - The textual name (e.g. ShiftOrNot) of the delegate. 
- std::vector delegate_specs; - for (int i = 0; i < delegates.length(); ++i) { - Frame frame(store, delegates.get(i)); - string runtime = frame.GetText("runtime").str(); - - Delegate *d = Delegate::Create(runtime); - d->set_cell(network.GetCell(frame.GetText("cell").str())); - d->set_name(frame.GetText("name").str()); - d->set_runtime(runtime); - - delegates_.push_back(d); - delegate_specs.push_back(frame); - } - - // Initialize delegates. Delegates can choose to access other delegates in - // the cascade at this point. - int i = 0; - for (auto *d : delegates_) { - d->Initialize(this, delegate_specs[i++]); - } -} - -void Cascade::FallbackAction( - const ParserState *state, ParserAction *action) const { - *action = (state->current() < state->end()) ? shift_ : stop_; -} - -void DelegateInstance::Compute( - myelin::Channel *activations, int step, ParserAction *output) { - instance_.Clear(); - instance_.Set(delegate_->input(), activations, step); - instance_.Compute(); - delegate_->Compute(&instance_, output); -} - -CascadeInstance::CascadeInstance(const Cascade *cascade) - : cascade_(cascade) { - for (auto *d : cascade->delegates_) { - instances_.push_back(new DelegateInstance(d)); - } -} - -CascadeInstance::~CascadeInstance() { - for (auto *i : instances_) delete i; -} - -void CascadeInstance::Compute(myelin::Channel *activations, - ParserState *state, - ParserAction *output, - Trace *trace) { - int current = 0; - while (true) { - // Execute the current delegate's instance. - instances_[current]->Compute(activations, state->step(), output); - if (trace != nullptr) trace->Action(*output); - - // If there is a cascade down the chain then follow it. - // To avoid potential infinite loops, cascades to delegates - // up in the chain are disallowed. - bool is_cascade = output->type == ParserAction::CASCADE; - if (is_cascade && (output->delegate > current)) { - current = output->delegate; - continue; - } - - // If we have an applicable action then we are done with the cascade. - if (!is_cascade && state->CanApply(*output)) return; - - // Return a fallback action. - cascade_->FallbackAction(state, output); - if (trace != nullptr) trace->Fallback(*output); - return; - } -} - -} // namespace nlp -} // namespace sling diff --git a/sling/nlp/parser/cascade.h b/sling/nlp/parser/cascade.h deleted file mode 100644 index 2d1b9ca1..00000000 --- a/sling/nlp/parser/cascade.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2017 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef SLING_NLP_PARSER_CASCADE_H_ -#define SLING_NLP_PARSER_CASCADE_H_ - -#include -#include - -#include "sling/base/registry.h" -#include "sling/frame/object.h" -#include "sling/frame/store.h" -#include "sling/myelin/compute.h" -#include "sling/myelin/flow.h" -#include "sling/nlp/parser/parser-action.h" -#include "sling/nlp/parser/parser-state.h" -#include "sling/nlp/parser/trace.h" - -namespace sling { -namespace nlp { - -class Cascade; -class CascadeInstance; - -// Delegate runtime implementation. 
-class Delegate : public Component { - public: - virtual ~Delegate() {} - - // Initializes the delegate, which is a part of 'cascade', and whose - // specification is in 'spec'. The delegate implementation is - // already available in 'cell_'. - virtual void Initialize(const Cascade *cascade, const Frame &spec) = 0; - - // Modifies 'action' with the result of the already computed 'instance'. - virtual void Compute( - myelin::Instance *instance, ParserAction *action) const = 0; - - // Returns the location of the delegate input. - virtual myelin::Tensor *input() const { return input_; } - - // Cell accessors. - myelin::Cell *cell() const { return cell_; } - void set_cell(myelin::Cell *cell) { cell_ = cell; } - - // Other accessors. - const string &name() const { return name_; } - const string &runtime() const { return runtime_; } - void set_name(const string &n) { name_ = n; } - void set_runtime(const string &r) { runtime_ = r; } - - protected: - // Input to the delegate. - myelin::Tensor *input_ = nullptr; - - // Delegate cell. - myelin::Cell *cell_ = nullptr; - - // Name and runtime. - string name_; - string runtime_; -}; - -#define REGISTER_DELEGATE_RUNTIME(type, component) \ - REGISTER_COMPONENT_TYPE(sling::nlp::Delegate, type, component) - -// Cascade model. -class Cascade { - public: - Cascade(); - ~Cascade(); - - // Initializes the cascade by reading its specification from 'spec' - // and implementation from 'network'. - void Initialize(const myelin::Network &network, const Frame &spec); - - // Delegate accessors. - Delegate *delegate(int i) const { return delegates_[i]; } - int size() const { return delegates_.size(); } - - // Sets 'action' to the fallback action for 'state'. - void FallbackAction(const ParserState *state, ParserAction *action) const; - - private: - friend class CascadeInstance; - - // List of delegates. - std::vector delegates_; - - // Fallback actions. - ParserAction shift_; - ParserAction stop_; -}; - -// Instance for running a single delegate. -class DelegateInstance { - public: - DelegateInstance(const Delegate *d) : delegate_(d), instance_(d->cell()) {} - - // Runs the delegate with the specified input activation, and populates - // 'output' with the resulting action. - void Compute(myelin::Channel *activations, int step, ParserAction *output); - - private: - // Delegate. Not owned. - const Delegate *delegate_ = nullptr; - - // Underlying Myelin instance. - myelin::Instance instance_; -}; - -// Runs an instance of a cascade on a ParserState. -class CascadeInstance { - public: - CascadeInstance(const Cascade *cascade); - ~CascadeInstance(); - - // Outputs in 'output' the result of running the whole cascade on 'state'. - // The activation at index 'step' is used as input to all the delegates. - // Adds the predicted and final actions to 'trace' if it is not nullptr. 
- void Compute(myelin::Channel *activations, - ParserState *state, - ParserAction *output, - Trace *trace = nullptr); - - private: - const Cascade *const cascade_ = nullptr; // cascade; not owned - std::vector instances_; // delegate-specific instances -}; - -} // namespace nlp -} // namespace sling - -#endif // SLING_NLP_PARSER_CASCADE_H_ diff --git a/sling/nlp/parser/caspar-trainer.cc b/sling/nlp/parser/caspar-trainer.cc index cabe05ff..4461dfb0 100644 --- a/sling/nlp/parser/caspar-trainer.cc +++ b/sling/nlp/parser/caspar-trainer.cc @@ -35,21 +35,23 @@ class MultiClassDelegateLearner : public DelegateLearner { : name_(name), loss_(name + "_loss") {} void Build(Flow *flow, - Flow::Variable *activations, - Flow::Variable *dactivations, + Flow::Variable *activation, + Flow::Variable *dactivation, bool learn) override { FlowBuilder f(flow, name_); - int dim = activations->elements(); + int dim = activation->elements(); int size = actions_.size(); - auto *W = f.Random(f.Parameter("W", DT_FLOAT, {dim, size})); - auto *b = f.Random(f.Parameter("b", DT_FLOAT, {1, size})); + auto *W = f.Parameter("W", DT_FLOAT, {dim, size}); + auto *b = f.Parameter("b", DT_FLOAT, {1, size}); + f.RandomNormal(W); auto *input = f.Placeholder("input", DT_FLOAT, {1, dim}, true); auto *logits = f.Name(f.Add(f.MatMul(input, W), b), "logits"); - logits->set_out(); - f.Name(f.ArgMax(logits), "output"); + if (learn) logits->set_out(); + auto *output = f.Name(f.ArgMax(logits), "output"); + if (!learn) output->set_out(); - flow->Connect({activations, input}); + flow->Connect({activation, input}); if (learn) { Gradient(flow, f.func()); auto *dlogits = flow->GradientVar(logits); @@ -74,56 +76,11 @@ class MultiClassDelegateLearner : public DelegateLearner { return new DelegateInstance(this); } - void Save(Flow *flow, Builder *data) override { - // Save delegate type. - data->Add("name", name_); - data->Add("runtime", "SoftmaxDelegate"); - data->Add("cell", cell_->name()); - - // Save the action table. 
- Store *store = data->store(); - Handle n_type = store->Lookup("/table/action/type"); - Handle n_length = store->Lookup("/table/action/length"); - Handle n_source = store->Lookup("/table/action/source"); - Handle n_target = store->Lookup("/table/action/target"); - Handle n_role = store->Lookup("/table/action/role"); - Handle n_label = store->Lookup("/table/action/label"); - Handle n_delegate = store->Lookup("/table/action/delegate"); - - Array actions(store, actions_.size()); - int index = 0; - for (const ParserAction &action : actions_.list()) { - auto type = action.type; - Builder b(store); - b.Add(n_type, static_cast(type)); - - if (type == ParserAction::REFER || type == ParserAction::EVOKE) { - if (action.length > 0) { - b.Add(n_length, static_cast(action.length)); - } - } - if (type == ParserAction::ASSIGN || - type == ParserAction::ELABORATE || - type == ParserAction::CONNECT) { - if (action.source != 0) { - b.Add(n_source, static_cast(action.source)); - } - } - if (type == ParserAction::EMBED || - type == ParserAction::REFER || - type == ParserAction::CONNECT) { - if (action.target != 0) { - b.Add(n_target, static_cast(action.target)); - } - } - if (type == ParserAction::CASCADE) { - b.Add(n_delegate, static_cast(action.delegate)); - } - if (!action.role.IsNil()) b.Add(n_role, action.role); - if (!action.label.IsNil()) b.Add(n_label, action.label); - actions.set(index++, b.Create().handle()); - } - data->Add("actions", actions); + void Save(Flow *flow, Builder *spec) override { + spec->Add("name", name_); + spec->Add("type", "multiclass"); + spec->Add("cell", cell_->name()); + actions_.Write(spec); } // Multi-class delegate instance. @@ -132,8 +89,7 @@ class MultiClassDelegateLearner : public DelegateLearner { DelegateInstance(MultiClassDelegateLearner *learner) : learner_(learner), forward_(learner->cell_), - backward_(learner->dcell_) { - } + backward_(learner->dcell_) {} void CollectGradients(std::vector *gradients) override { gradients->push_back(&backward_); @@ -143,15 +99,15 @@ class MultiClassDelegateLearner : public DelegateLearner { backward_.Clear(); } - float Compute(float *activations, - float *dactivations, + float Compute(float *activation, + float *dactivation, const ParserAction &action) override { // Look up index for action. Skip backpropagation if action is unknown. int target = learner_->actions_.Index(action); if (target == -1) return 0.0; - // Compute logits from activations. - forward_.SetReference(learner_->input_, activations); + // Compute logits from activation. + forward_.SetReference(learner_->input_, activation); forward_.Compute(); // Compute loss. @@ -161,15 +117,15 @@ class MultiClassDelegateLearner : public DelegateLearner { // Backpropagate loss. backward_.Set(learner_->primal_, &forward_); - backward_.SetReference(learner_->dinput_, dactivations); + backward_.SetReference(learner_->dinput_, dactivation); backward_.Compute(); return loss; } - void Predict(float *activations, ParserAction *action) override { + void Predict(float *activation, ParserAction *action) override { // Predict action from activations. - forward_.SetReference(learner_->input_, activations); + forward_.SetReference(learner_->input_, activation); forward_.Compute(); int argmax = *forward_.Get(learner_->output_); *action = learner_->actions_.Action(argmax); @@ -226,7 +182,15 @@ class CasparTrainer : public ParserTrainer { public: // Set up caspar parser model. void Setup(task::Task *task) override { + // Get training parameters. 
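For orientation, a sketch of how a trainer step drives one delegate instance per transition; `activation` and `dactivation` point into the decoder's activation channel and its gradient, and the driver shown here is a simplification of the real ParserTrainer loop rather than code from this change:

```cpp
// Training: compute the loss for the gold action and accumulate the gradient
// of the step activation into dactivation.
float loss = instance->Compute(activation, dactivation, gold_action);

// Prediction: run the same cell and map the argmax back to a ParserAction.
ParserAction predicted;
instance->Predict(activation, &predicted);
```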
+ task->Fetch("max_source", &max_source_); + task->Fetch("max_target", &max_target_); + + // Reset parser state between sentences. + sentence_reset_ = true; + // Collect word and action vocabularies from training corpus. + ActionTable actions; training_corpus_->Rewind(); for (;;) { // Get next document. @@ -249,23 +213,21 @@ class CasparTrainer : public ParserTrainer { if (action.target > max_target_) skip = true; break; case ParserAction::ASSIGN: - case ParserAction::EMBED: - case ParserAction::ELABORATE: if (action.source > max_source_) skip = true; break; default: break; } - if (!skip) actions_.Add(action); + if (!skip) actions.Add(action); }); delete document; } - roles_.Init(actions_.list()); + roles_.Add(actions.list()); // Set up delegates. delegates_.push_back(new ShiftMarkOtherDelegateLearner(1)); - delegates_.push_back(new ClassificationDelegateLearner(actions_)); + delegates_.push_back(new ClassificationDelegateLearner(actions)); } // Transition generator. @@ -281,18 +243,10 @@ class CasparTrainer : public ParserTrainer { }); } - // Save action table in model. - void SaveModel(Flow *flow, Store *store) override { - // Save action table in store. - Builder table(store); - table.AddId("/table"); - actions_.Write(&table); - table.Create(); - } - private: - // Parser actions. - ActionTable actions_; + // Hyperparameters. + int max_source_ = 5; + int max_target_ = 10; }; REGISTER_TASK_PROCESSOR("caspar-trainer", CasparTrainer); diff --git a/sling/nlp/parser/frame-evaluation.cc b/sling/nlp/parser/frame-evaluation.cc index f2694239..66e74b96 100644 --- a/sling/nlp/parser/frame-evaluation.cc +++ b/sling/nlp/parser/frame-evaluation.cc @@ -80,8 +80,10 @@ void FrameEvaluation::Evaluate(ParallelCorpus *corpus, Output *output) { // Benchmarks. auto &mention = output->mention; auto &frame = output->frame; - auto &type = output->type; + auto &pair = output->pair; + auto &edge = output->edge; auto &role = output->role; + auto &type = output->type; auto &label = output->label; // Statistics counters. @@ -134,9 +136,11 @@ void FrameEvaluation::Evaluate(ParallelCorpus *corpus, Output *output) { // Compute role precision and recall. RoleAccuracy(store, g2p_frame_alignment, - &type.recall, &role.recall, &label.recall); + &pair.recall, &edge.recall, &role.recall, + &type.recall, &label.recall); RoleAccuracy(store, p2g_frame_alignment, - &type.precision, &role.precision, &label.precision); + &pair.precision, &edge.precision, &role.precision, + &type.precision, &label.precision); // Update statistics. output->num_golden_spans += golden_mentions.size(); @@ -232,8 +236,10 @@ string FrameEvaluation::Benchmark::Summary() const { void FrameEvaluation::Output::GetScores(Scores *scores) const { mention.GetScores("SPAN", scores); frame.GetScores("FRAME", scores); - type.GetScores("TYPE", scores); + pair.GetScores("PAIR", scores); + edge.GetScores("EDGE", scores); role.GetScores("ROLE", scores); + type.GetScores("TYPE", scores); label.GetScores("LABEL", scores); slot.GetScores("SLOT", scores); combined.GetScores("COMBINED", scores); @@ -403,7 +409,8 @@ void FrameEvaluation::AlignmentAccuracy( void FrameEvaluation::RoleAccuracy( Store *store, const Alignment &alignment, - Metric *type, Metric *role, Metric *label) { + Metric *pair, Metric *edge, Metric *role, + Metric *type, Metric *label) { for (const auto &a : alignment) { Frame source(store, a.first); Frame target(store, a.second); @@ -412,15 +419,18 @@ void FrameEvaluation::RoleAccuracy( for (const Slot &s : source) { if (s.name.IsIsA()) { // Check type. 
- type->prediction(HasRole(target, Handle::isa(), s.value)); + type->prediction(HasSlot(target, Handle::isa(), s.value)); } else if (s.name.IsId() || s.name.IsIs()) { // Ignore special roles. } else if (s.value.IsLocalRef()) { // Check frame-to-frame role. - role->prediction(HasRole(target, s.name, alignment.Lookup(s.value))); + Handle value = alignment.Lookup(s.value); + pair->prediction(!value.IsNil()); + edge->prediction(HasValue(target, value)); + role->prediction(HasSlot(target, s.name, value)); } else { // Check label role. - label->prediction(HasRole(target, s.name, s.value)); + label->prediction(HasSlot(target, s.name, s.value)); } } } @@ -434,7 +444,7 @@ int FrameEvaluation::SlotCount(const Frame &f, Handle name) { return n; } -bool FrameEvaluation::HasRole(const Frame &f, Handle name, Handle value) { +bool FrameEvaluation::HasSlot(const Frame &f, Handle name, Handle value) { if (f.invalid() || name.IsNil() || value.IsNil()) return false; for (const Slot &s : f) { if (s.name == name && s.value == value) return true; @@ -442,5 +452,13 @@ bool FrameEvaluation::HasRole(const Frame &f, Handle name, Handle value) { return false; } +bool FrameEvaluation::HasValue(const Frame &f, Handle value) { + if (f.invalid() || value.IsNil()) return false; + for (const Slot &s : f) { + if (s.value == value) return true; + } + return false; +} + } // namespace nlp } // namespace sling diff --git a/sling/nlp/parser/frame-evaluation.h b/sling/nlp/parser/frame-evaluation.h index 9775bcb8..c8988ca5 100644 --- a/sling/nlp/parser/frame-evaluation.h +++ b/sling/nlp/parser/frame-evaluation.h @@ -144,8 +144,10 @@ class FrameEvaluation { Benchmark mention; Benchmark frame; Benchmark type; - Benchmark role; Benchmark label; + Benchmark pair; + Benchmark edge; + Benchmark role; Benchmark slot; Benchmark combined; @@ -213,13 +215,17 @@ class FrameEvaluation { // Computes role accuracy. static void RoleAccuracy(Store *store, const Alignment &alignment, - Metric *type, Metric *role, Metric *label); + Metric *pair, Metric *edge, Metric *role, + Metric *type, Metric *label); // Counts the number of slots with a given name. static int SlotCount(const Frame &f, Handle name); - // Checks if frame has a role with a given name and value. - static bool HasRole(const Frame &f, Handle name, Handle value); + // Checks if frame has a slot with a given name and value. + static bool HasSlot(const Frame &f, Handle name, Handle value); + + // Checks if frame has a slot with a given value. + static bool HasValue(const Frame &f, Handle value); }; } // namespace nlp diff --git a/sling/nlp/parser/multiclass-delegate.cc b/sling/nlp/parser/multiclass-delegate.cc new file mode 100644 index 00000000..1fdf8c0a --- /dev/null +++ b/sling/nlp/parser/multiclass-delegate.cc @@ -0,0 +1,69 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "sling/nlp/parser/action-table.h" +#include "sling/nlp/parser/parser.h" + +namespace sling { +namespace nlp { + +using namespace myelin; + +// Deletegate for fixed action classification. +class MultiClassDelegate : public Delegate { + public: + void Initialize(const Network &network, const Frame &spec) override { + cell_ = network.GetCell(spec.GetString("cell")); + input_ = cell_->GetParameter(cell_->name() + "/input"); + output_ = cell_->GetParameter(cell_->name() + "/output"); + actions_.Read(spec); + } + + DelegateInstance *CreateInstance() override { + return new MultiClassDelegateInstance(this); + } + + // Multi-class delegate instance. + class MultiClassDelegateInstance : public DelegateInstance { + public: + MultiClassDelegateInstance(MultiClassDelegate *delegate) + : delegate_(delegate), + data_(delegate->cell_) {} + + void Predict(float *activation, ParserAction *action) override { + // Predict action from activations. + data_.SetReference(delegate_->input_, activation); + data_.Compute(); + int argmax = *data_.Get(delegate_->output_); + *action = delegate_->actions_.Action(argmax); + } + + private: + MultiClassDelegate *delegate_; + Instance data_; + }; + + private: + ActionTable actions_; // action table for multi-class classification + + Cell *cell_ = nullptr; // cell for computation + Tensor *input_ = nullptr; // input for activations + Tensor *output_ = nullptr; // output prediction +}; + +REGISTER_DELEGATE("multiclass", MultiClassDelegate); + +} // namespace nlp +} // namespace sling + diff --git a/sling/nlp/parser/parser-action.cc b/sling/nlp/parser/parser-action.cc index 89764f53..053e0248 100644 --- a/sling/nlp/parser/parser-action.cc +++ b/sling/nlp/parser/parser-action.cc @@ -25,12 +25,9 @@ string ParserAction::TypeName(Type type) { case ParserAction::REFER: return "REFER"; case ParserAction::CONNECT: return "CONNECT"; case ParserAction::ASSIGN: return "ASSIGN"; - case ParserAction::EMBED: return "EMBED"; - case ParserAction::ELABORATE: return "ELABORATE"; case ParserAction::CASCADE: return "CASCADE"; case ParserAction::MARK: return "MARK"; case ParserAction::SHIFT: return "SHIFT"; - case ParserAction::STOP: return "STOP"; } return ""; @@ -46,10 +43,10 @@ Frame ParserAction::AsFrame(Store *store, const string &prefix) const { if (length != 0) builder.Add(prefix + "length", length); if (!label.IsNil()) builder.Add(prefix + "label", label); if (!role.IsNil()) builder.Add(prefix + "role", role); - if (type == REFER || type == CONNECT || type == EMBED) { + if (type == REFER || type == CONNECT) { builder.Add(prefix + "target", target); } - if (type == ASSIGN || type == CONNECT || type == ELABORATE) { + if (type == ASSIGN || type == CONNECT) { builder.Add(prefix + "source", source); } if (type == CASCADE) { @@ -75,22 +72,11 @@ string ParserAction::ToString(Store *store) const { StrAppend(&s, source, " -> ", store->DebugString(role), " -> ", store->DebugString(label)); break; - case ParserAction::EMBED: - StrAppend(&s, "TYPE(", store->DebugString(label), ")", - " -> ", store->DebugString(role), " -> ", - target); - break; - case ParserAction::ELABORATE: - StrAppend(&s, "TYPE(", store->DebugString(label), ")", - " <- ", store->DebugString(role), " <- ", - source); - break; case ParserAction::CASCADE: StrAppend(&s, "(delegate=", delegate, ")"); break; case ParserAction::SHIFT: case ParserAction::MARK: - case ParserAction::STOP: default: s.pop_back(); break; diff --git a/sling/nlp/parser/parser-action.h b/sling/nlp/parser/parser-action.h index 9cf42fee..ef2e03d2 
100644 --- a/sling/nlp/parser/parser-action.h +++ b/sling/nlp/parser/parser-action.h @@ -31,45 +31,29 @@ struct ParserAction { enum Type : uint8 { // Skips the next input token. Only valid when not at the end of the input // buffer. - SHIFT, - - // Signals that we have reach the end of the parse. This is only valid when - // at the end of the input buffer. Multiple STOP actions can be added to - // the transition sequence to make all sequences in a beam have the same - // length. - STOP, + SHIFT = 0, // Evokes frame of with type 'type' from the next 'length' tokens in the // input. The new frame will become the center of attention. - EVOKE, + EVOKE = 2, // Makes a new mention of an existing frame. This frame will become the new // center of attention. - REFER, + REFER = 3, // Adds slot to frame 'source' with name 'role' and value 'target'. The // source frame become the new center of attention. - CONNECT, + CONNECT = 4, // Adds slot to frame 'source' with name 'role' and value 'type' and moves // frame to the center of attention. - ASSIGN, - - // Create new frame with type 'type' and add a slot to it with name 'role' - // and value 'target', where target is a frame in the attention buffer. - // The new frame become the new center of attention. - EMBED, - - // Create new frame with type 'type' and add a slot to an existing frame - // 'source' in the attention buffer with 'role' set to the new frame. - // The new frame become the new center of attention. - ELABORATE, + ASSIGN = 5, // Delegate to another member (specified by 'delegate') of the cascade. - CASCADE, + CASCADE = 8, // Mark the current token as the beginning of a span. - MARK, + MARK = 9, }; // Number of action types. @@ -82,16 +66,16 @@ struct ParserAction { // Length of the evoked frame for EVOKE and REFER. uint8 length; - // Source frame index for CONNECT, ASSIGN, ELABORATE. + // Source frame index for CONNECT and ASSIGN. uint8 source; - // Target frame index for CONNECT, EMBED, REFER. + // Target frame index for CONNECT and REFER. uint8 target; - // Role argument for CONNECT, ASSIGN, EMBED, ELABORATE. + // Role argument for CONNECT and ASSIGN. Handle role; - // Frame type for EVOKE, EMBED, ELABORATE, and value for ASSIGN. + // Frame type for EVOKE and value for ASSIGN. Handle label; // Index of the delegate for CASCADE actions. @@ -145,11 +129,6 @@ struct ParserAction { return ParserAction(ParserAction::SHIFT); } - // Returns a STOP action. - static ParserAction Stop() { - return ParserAction(ParserAction::STOP); - } - // Returns an EVOKE action. static ParserAction Evoke(uint8 length, Handle type) { ParserAction action(ParserAction::EVOKE, length); diff --git a/sling/nlp/parser/parser-features.cc b/sling/nlp/parser/parser-features.cc index d75ceddc..fc7306b9 100644 --- a/sling/nlp/parser/parser-features.cc +++ b/sling/nlp/parser/parser-features.cc @@ -33,7 +33,6 @@ myelin::Tensor *ParserFeatureModel::GetParam(const string &name, } void ParserFeatureModel::Init(myelin::Cell *cell, - myelin::Flow::Blob *spec, const RoleSet *roles, int frame_limit) { // Store cell that contains the feature inputs. @@ -42,39 +41,22 @@ void ParserFeatureModel::Init(myelin::Cell *cell, frame_limit_ = frame_limit; // Get feature inputs. 
- lr_focus_feature_ = GetParam("lr", true); - rl_focus_feature_ = GetParam("rl", true); - lr_attention_feature_ = GetParam("frame-end-lr", true); - rl_attention_feature_ = GetParam("frame-end-rl", true); - frame_create_feature_ = GetParam("frame-creation-steps", true); - frame_focus_feature_ = GetParam("frame-focus-steps", true); + token_feature_ = GetParam("token", true); + + attention_tokens_feature_ = GetParam("attention_tokens", true); + attention_steps_feature_ = GetParam("attention_steps", true); + + mark_tokens_feature_ = GetParam("mark_tokens", true); + mark_steps_feature_ = GetParam("mark_steps", true); + history_feature_ = GetParam("history", true); - mark_lr_feature_ = GetParam("mark-lr", true); - mark_rl_feature_ = GetParam("mark-rl", true); - mark_step_feature_ = GetParam("mark-step", true); - mark_distance_feature_ = GetParam("mark-distance", true); - out_roles_feature_ = GetParam("out-roles", true); - in_roles_feature_ = GetParam("in-roles", true); - unlabeled_roles_feature_ = GetParam("unlabeled-roles", true); - labeled_roles_feature_ = GetParam("labeled-roles", true); + + out_roles_feature_ = GetParam("out_roles", true); + in_roles_feature_ = GetParam("in_roles", true); + unlabeled_roles_feature_ = GetParam("unlabeled_roles", true); + labeled_roles_feature_ = GetParam("labeled_roles", true); // Get feature sizes. - std::vector attention_features { - lr_attention_feature_, - rl_attention_feature_, - frame_create_feature_, - frame_focus_feature_, - }; - for (auto *f : attention_features) { - if (!f) continue; - if (f->elements() > attention_depth_) { - attention_depth_ = f->elements(); - } - } - for (auto *f : attention_features) { - if (!f) continue; - CHECK_EQ(attention_depth_, f->elements()); - } if (history_feature_ != nullptr) { history_size_ = history_feature_->elements(); } @@ -90,110 +72,58 @@ void ParserFeatureModel::Init(myelin::Cell *cell, if (labeled_roles_feature_ != nullptr) { labeled_roles_size_ = labeled_roles_feature_->elements(); } - if (mark_lr_feature_ != nullptr) { - mark_depth_ = mark_lr_feature_->elements(); - } - if (mark_rl_feature_ != nullptr) { - if (mark_depth_ == 0) { - mark_depth_ = mark_rl_feature_->elements(); - } else { - CHECK_EQ(mark_depth_, mark_rl_feature_->elements()); - } - } - if (mark_distance_feature_ != nullptr) { - CHECK(spec != nullptr); - string bins_str = spec->GetAttr("mark_distance_bins"); - std::vector bins; - int start = 0; - while (true) { - ssize_t index = bins_str.find(' ', start); - if (index != string::npos) { - string s = bins_str.substr(start, index - start); - bins.push_back(std::stoi(s)); - start = index + 1; - } else { - bins.push_back(std::stoi(bins_str.substr(start))); - break; - } - } - int distance = 0; - for (int i = 0; i < bins.size(); ++i) { - while (distance <= bins[i]) { - mark_distance_bins_.push_back(i); - distance++; - } - } - mark_distance_bins_.push_back(bins.size()); + if (mark_tokens_feature_ != nullptr) { + mark_depth_ = mark_tokens_feature_->elements(); } - // Get links. - lr_lstm_ = GetParam("link/lr_lstm"); - rl_lstm_ = GetParam("link/rl_lstm"); + // Get channel links. + tokens_ = GetParam("tokens"); steps_ = GetParam("steps"); - hidden_ = GetParam("hidden"); + + // Get output step activation from decoder. 
+ activation_ = GetParam("activation"); }; -void ParserFeatureExtractor::Attach(const myelin::BiChannel &bilstm, +void ParserFeatureExtractor::Attach(myelin::Channel *encodings, myelin::Channel *activations, myelin::Instance *instance) { const ParserFeatureModel *fm = features_; - instance->Set(fm->lr_lstm_, bilstm.lr); - instance->Set(fm->rl_lstm_, bilstm.rl); + instance->Set(fm->tokens_, encodings); instance->Set(fm->steps_, activations); - instance->Set(fm->hidden_, activations, state_->step()); + instance->Set(fm->activation_, activations, state_->step()); } void ParserFeatureExtractor::Extract(myelin::Instance *instance) { const ParserFeatureModel *fm = features_; Data data(instance); - // Extract LSTM focus features. + // Extract current token feature. int current = state_->current() - state_->begin(); - if (state_->current() == state_->end()) current = -1; - int *lr_focus = data.Get(fm->lr_focus_feature_); - int *rl_focus = data.Get(fm->rl_focus_feature_); - if (lr_focus != nullptr) *lr_focus = current; - if (rl_focus != nullptr) *rl_focus = current; + int *token = data.Get(fm->token_feature_); + if (token != nullptr) *token = current; // Extract features from the mark stack. auto &marks = state_->marks(); - int *lr_mark = data.Get(fm->mark_lr_feature_); - int *rl_mark = data.Get(fm->mark_rl_feature_); - int *mark_step = data.Get(fm->mark_step_feature_); + int *mark_tokens = data.Get(fm->mark_tokens_feature_); + int *mark_steps = data.Get(fm->mark_steps_feature_); for (int d = 0; d < fm->mark_depth_; ++d) { if (d < marks.size()) { const auto &m = marks[marks.size() - 1 - d]; int token = m.token - state_->begin(); - if (lr_mark != nullptr) lr_mark[d] = token; - if (rl_mark != nullptr) rl_mark[d] = token; - if (mark_step != nullptr) mark_step[d] = m.step; + if (mark_tokens != nullptr) mark_tokens[d] = token; + if (mark_steps != nullptr) mark_steps[d] = m.step; } else { - if (lr_mark != nullptr) lr_mark[d] = -1; - if (rl_mark != nullptr) rl_mark[d] = -1; - if (mark_step != nullptr) mark_step[d] = -1; - } - } - - int *mark_distance = data.Get(fm->mark_distance_feature_); - if (mark_distance != nullptr) { - *mark_distance = fm->mark_distance_bins_.back(); - if (!marks.empty()) { - int distance = state_->current() - marks[marks.size() - 1].token; - if (distance < fm->mark_distance_bins_.size()) { - *mark_distance = fm->mark_distance_bins_[distance]; - } + if (mark_tokens != nullptr) mark_tokens[d] = -1; + if (mark_steps != nullptr) mark_steps[d] = -1; } } - // Extract frame attention, create, and focus features. - if (fm->attention_depth_ > 0) { - int *lr = data.Get(fm->lr_attention_feature_); - int *rl = data.Get(fm->rl_attention_feature_); - int *create = data.Get(fm->frame_create_feature_); - int *focus = data.Get(fm->frame_focus_feature_); - for (int d = 0; d < fm->attention_depth_; ++d) { - int token = -1; - int created = -1; + // Extract token and step attention features. + if (fm->frame_limit_ > 0) { + int *token_attention = data.Get(fm->attention_tokens_feature_); + int *step_attention = data.Get(fm->attention_steps_feature_); + for (int d = 0; d < fm->frame_limit_; ++d) { + int evoked = -1; int focused = -1; if (d < state_->AttentionSize()) { // Get frame from attention buffer. @@ -201,18 +131,15 @@ void ParserFeatureExtractor::Extract(myelin::Instance *instance) { // Get end token for phrase that evoked frame. 
if (attention.span != nullptr) { - token = attention.span->end(); - if (token != -1) token -= state_->begin() + 1; + evoked = attention.span->end(); + if (evoked != -1) evoked -= state_->begin() + 1; } - // Get the step numbers that created and focused the frame. - created = attention.created; + // Get the step numbers that focused the frame. focused = attention.focused; } - if (lr != nullptr) lr[d] = token; - if (rl != nullptr) rl[d] = token; - if (create != nullptr) create[d] = created; - if (focus != nullptr) focus[d] = focused; + if (token_attention != nullptr) token_attention[d] = evoked; + if (step_attention != nullptr) step_attention[d] = focused; } } @@ -273,35 +200,6 @@ void ParserFeatureExtractor::Extract(myelin::Instance *instance) { } }; -void ParserFeatureExtractor::TraceFeatures(myelin::Instance *instance, - Trace *trace) const { - trace->steps.emplace_back(); - auto &step = trace->steps.back(); - step.current = state_->current(); - - Data data(instance); - const ParserFeatureModel *fm = features_; - step.Add(data.Get(fm->lr_focus_feature_), 1, "lr"); - step.Add(data.Get(fm->rl_focus_feature_), 1, "rl"); - step.Add(data.Get(fm->mark_lr_feature_), fm->mark_depth_, "mark-lr"); - step.Add(data.Get(fm->mark_rl_feature_), fm->mark_depth_, "mark-rl"); - step.Add(data.Get(fm->mark_step_feature_), fm->mark_depth_, "mark-step"); - - int depth = fm->attention_depth_; - step.Add(data.Get(fm->lr_attention_feature_), depth, "frame-end-lr"); - step.Add(data.Get(fm->rl_attention_feature_), depth, "frame-end-rl"); - step.Add(data.Get(fm->frame_create_feature_), depth, "frame-creation-steps"); - step.Add(data.Get(fm->frame_focus_feature_), depth, "frame-focus-steps"); - step.Add(data.Get(fm->history_feature_), fm->history_size_, "history"); - step.Add(data.Get(fm->mark_distance_feature_), 1, "mark-distance"); - step.Add(data.Get(fm->out_roles_feature_), fm->out_roles_size_, "out-roles"); - step.Add(data.Get(fm->in_roles_feature_), fm->in_roles_size_, "in-roles"); - step.Add(data.Get(fm->unlabeled_roles_feature_), - fm->unlabeled_roles_size_, "unlabeled-roles"); - step.Add(data.Get(fm->labeled_roles_feature_), - fm->labeled_roles_size_, "labeled-roles"); -} - } // namespace nlp } // namespace sling diff --git a/sling/nlp/parser/parser-features.h b/sling/nlp/parser/parser-features.h index d6fd40f8..32d0ad51 100644 --- a/sling/nlp/parser/parser-features.h +++ b/sling/nlp/parser/parser-features.h @@ -18,10 +18,8 @@ #include #include "sling/myelin/compute.h" -#include "sling/myelin/rnn.h" #include "sling/nlp/parser/parser-state.h" #include "sling/nlp/parser/roles.h" -#include "sling/nlp/parser/trace.h" namespace sling { namespace nlp { @@ -33,12 +31,11 @@ class ParserFeatureModel { public: // Initialize feature model. void Init(myelin::Cell *cell, - myelin::Flow::Blob *spec, const RoleSet *roles, int frame_limit); - // Return tensor for hidden layer activations. - const myelin::Tensor *hidden() const { return hidden_; } + // Return output with activation for current step. + const myelin::Tensor *activation() const { return activation_; } private: // Get parameter tensor in decoder cell. @@ -50,46 +47,40 @@ class ParserFeatureModel { // Set of roles considered. const RoleSet *roles_; - // Maximum attention index considered (exclusive). + // Maximum frame attention index considered (exclusive). int frame_limit_; // Features. 
- myelin::Tensor *lr_focus_feature_; // LR LSTM input focus feature - myelin::Tensor *rl_focus_feature_; // RL LSTM input focus feature + myelin::Tensor *token_feature_; // current token feature - myelin::Tensor *lr_attention_feature_; // LR LSTM frame attention feature - myelin::Tensor *rl_attention_feature_; // LR LSTM frame attention feature + myelin::Tensor *attention_tokens_feature_; // token attention feature + myelin::Tensor *attention_steps_feature_; // step attention feature - myelin::Tensor *frame_create_feature_; // FF frame create feature - myelin::Tensor *frame_focus_feature_; // FF frame focus feature + myelin::Tensor *history_feature_; // history feature - myelin::Tensor *history_feature_; // history feature + myelin::Tensor *mark_tokens_feature_; // mark tokens feature + myelin::Tensor *mark_steps_feature_; // mark steps feature - myelin::Tensor *mark_lr_feature_; // LR LSTM mark-token feature - myelin::Tensor *mark_rl_feature_; // RL LSTM mark-token feature - myelin::Tensor *mark_step_feature_; // mark token step feature - myelin::Tensor *mark_distance_feature_; // mark token distance feature - - myelin::Tensor *out_roles_feature_; // out roles feature - myelin::Tensor *in_roles_feature_; // in roles feature - myelin::Tensor *unlabeled_roles_feature_; // unlabeled roles feature - myelin::Tensor *labeled_roles_feature_; // labeled roles feature + myelin::Tensor *out_roles_feature_; // out roles feature + myelin::Tensor *in_roles_feature_; // in roles feature + myelin::Tensor *unlabeled_roles_feature_; // unlabeled roles feature + myelin::Tensor *labeled_roles_feature_; // labeled roles feature // Feature dimensions. - int mark_depth_ = 0; // mark stack depth to use - int attention_depth_ = 0; // number of attention features - int history_size_ = 0; // number of history features - int out_roles_size_ = 0; // max number of out roles - int in_roles_size_ = 0; // max number of in roles - int labeled_roles_size_ = 0; // max number of unlabeled roles - int unlabeled_roles_size_ = 0; // max number of labeled roles - std::vector mark_distance_bins_; // distance bins for mark tokens - - // Links. - myelin::Tensor *lr_lstm_; // link to LR LSTM hidden layer - myelin::Tensor *rl_lstm_; // link to RL LSTM hidden layer - myelin::Tensor *steps_; // link to FF step hidden layer - myelin::Tensor *hidden_; // link to FF hidden layer output + int mark_depth_ = 0; // mark stack depth to use + int history_size_ = 0; // number of history features + int out_roles_size_ = 0; // max number of out roles + int in_roles_size_ = 0; // max number of in roles + int labeled_roles_size_ = 0; // max number of unlabeled roles + int unlabeled_roles_size_ = 0; // max number of labeled roles + + // Channel links. + myelin::Tensor *tokens_; // link to token encodings + myelin::Tensor *steps_; // link to step activations + + + // Output with activation for current step. + myelin::Tensor *activation_; friend class ParserFeatureExtractor; }; @@ -103,16 +94,13 @@ class ParserFeatureExtractor { : features_(features), state_(state) {} // Attach instance to input and output channels. - void Attach(const myelin::BiChannel &bilstm, + void Attach(myelin::Channel *encodings, myelin::Channel *activations, myelin::Instance *instance); // Extract features from current state and add these to the data instance. void Extract(myelin::Instance *data); - // Add extracted features to trace. 
- void TraceFeatures(myelin::Instance *instance, Trace *trace) const; - private: // Wrapper for data instance for looking up feature input tensors. class Data { diff --git a/sling/nlp/parser/parser-state.cc b/sling/nlp/parser/parser-state.cc index 8e4f47d8..081a360a 100644 --- a/sling/nlp/parser/parser-state.cc +++ b/sling/nlp/parser/parser-state.cc @@ -18,7 +18,6 @@ #include "sling/frame/object.h" #include "sling/frame/store.h" #include "sling/nlp/document/document.h" -#include "sling/string/strcat.h" namespace sling { namespace nlp { @@ -28,26 +27,7 @@ ParserState::ParserState(Document *document, int begin, int end) begin_(begin), end_(end), current_(begin), - step_(0), - done_(false) {} - -string ParserState::DebugString() const { - static const int MAX_ATTENTION = 10; - string s = - StrCat("Begin:", begin_, " End:", end_, " Current:", current_, - " Done: ", (done_ ? "Y" : "N"), " AttentionSize: ", - attention_.size(), "\n"); - for (int i = 0; i < attention_.size(); ++i) { - if (i == MAX_ATTENTION) { - StrAppend(&s, "..and ", (attention_.size() - MAX_ATTENTION), " more.\n"); - break; - } - StrAppend(&s, "AttentionIndex: ", i, - " FrameType:", store()->DebugString(Type(i)), "\n"); - } - - return s; -} + step_(0) {} void ParserState::Apply(const ParserAction &action) { switch (action.type) { @@ -55,10 +35,6 @@ void ParserState::Apply(const ParserAction &action) { Shift(); break; - case ParserAction::STOP: - Stop(); - break; - case ParserAction::MARK: Mark(); break; @@ -79,14 +55,6 @@ void ParserState::Apply(const ParserAction &action) { Assign(action.source, action.role, action.label); break; - case ParserAction::EMBED: - Embed(action.target, action.role, action.label); - break; - - case ParserAction::ELABORATE: - Elaborate(action.source, action.role, action.label); - break; - case ParserAction::CASCADE: LOG(FATAL) << "Cannot apply CASCADE action"; break; @@ -95,7 +63,6 @@ void ParserState::Apply(const ParserAction &action) { } bool ParserState::CanApply(const ParserAction &action) const { - if (done_) return false; switch (action.type) { case ParserAction::CASCADE: // Do not allow cascading back to the main cascade. @@ -105,10 +72,6 @@ bool ParserState::CanApply(const ParserAction &action) const { // Do not allow shifting past the end of the input buffer. return current_ < end_; - case ParserAction::STOP: - // Only allow stop if we are at the end of the input buffer. - return current_ == end_; - case ParserAction::MARK: return current_ < end_ && marks_.size() < MAX_MARK_DEPTH; @@ -199,32 +162,6 @@ bool ParserState::CanApply(const ParserAction &action) const { Frame frame(store(), Attention(source).frame); return !frame.Has(action.role, Attention(target).frame); } - - case ParserAction::EMBED: { - // Check that target is a valid index into the attention buffer. - if (action.target >= attention_.size()) return false; - - // Check that we haven't embedded the same frame the same way. - Handle target = Attention(action.target).frame; - for (const auto &e : embed_) { - if (e.first == target && e.second == action.label) return false; - } - - return true; - } - - case ParserAction::ELABORATE: { - // Check that source is a valid index into the attention buffer. - if (action.source >= attention_.size()) return false; - - // Check that we haven't elaborated the same frame the same way. 
- Handle source = Attention(action.source).frame; - for (const auto &e : elaborate_) { - if (e.first == source && e.second == action.label) return false; - } - - return true; - } } return false; @@ -233,14 +170,6 @@ bool ParserState::CanApply(const ParserAction &action) const { void ParserState::Shift() { // Move to the next token in the input buffer. current_++; - - // Clear the states for EMBED and ELABORATE. - embed_.clear(); - elaborate_.clear(); -} - -void ParserState::Stop() { - done_ = true; } void ParserState::Evoke(int length, Handle type) { @@ -315,41 +244,6 @@ void ParserState::Assign(int frame, Handle role, Handle value) { Center(frame, nullptr); } -void ParserState::Embed(int frame, Handle role, Handle type) { - // Create new frame with the specified type and add link to target frame. - Handle target = Attention(frame).frame; - Slot slots[2]; - slots[0].name = Handle::isa(); - slots[0].value = type; - slots[1].name = role; - slots[1].value = target; - Handle h = store()->AllocateFrame(slots, slots + 2); - embed_.emplace_back(target, type); - - // Add new frame to the attention buffer. - Add(h, nullptr); - - // Add new frame as a thematic frame to the document. - document_->AddTheme(h); -} - -void ParserState::Elaborate(int frame, Handle role, Handle type) { - // Create new frame with the specified type. - Handle source = Attention(frame).frame; - Slot slot(Handle::isa(), type); - Handle target = store()->AllocateFrame(&slot, &slot + 1); - - // Add new frame as a thematic frame to the document. - document_->AddTheme(target); - - // Add link to new frame from source frame. - store()->Add(source, role, target); - elaborate_.emplace_back(Attention(frame).frame, type); - - // Add new frame to the attention buffer. - Add(target, nullptr); -} - void ParserState::Add(Handle frame, Span *span) { attention_.emplace_back(frame, step_, span); } @@ -385,10 +279,6 @@ int ParserState::AttentionIndex(Handle frame, int k) const { return -1; } -Handle ParserState::Type(int index) const { - return store()->GetFrame(Attention(index).frame)->get(Handle::isa()); -} - } // namespace nlp } // namespace sling diff --git a/sling/nlp/parser/parser-state.h b/sling/nlp/parser/parser-state.h index 7b50f33d..0d3d9ddc 100644 --- a/sling/nlp/parser/parser-state.h +++ b/sling/nlp/parser/parser-state.h @@ -76,10 +76,6 @@ class ParserState { // ensure that 'action' is applicable using CanApply(). void Apply(const ParserAction &action); - // Returns the first type for a frame in the attention buffer. This will be - // the type specified when the frame was created with EVOKE/EMBED/ELABORATE. - Handle Type(int index) const; - // Gets the handles of the k frames that are closest to the center of // attention in the order of attention. There might be less than k frames if // there are fewer elements in the attention buffer. @@ -90,8 +86,8 @@ class ParserState { // limited to the top-k frames that are closest to the center of attention. int AttentionIndex(Handle frame, int k = -1) const; - // The parse is done when we have performed the first STOP action. - bool done() const { return done_; } + // The parse is done when we have reached the end of the sentence. + bool done() const { return current_ == end_; } // Returns slot in attention buffer. The center of attention has index 0. const AttentionSlot &Attention(int index) const { @@ -104,26 +100,20 @@ class ParserState { // Returns the size of the attention buffer. 
int AttentionSize() const { return attention_.size(); } - // Returns whether 'action' can be applied to the state. + // Returns whether action can be applied to the state. bool CanApply(const ParserAction &action) const; - // Returns a human-readable representation of the state. - string DebugString() const; - // Returns the underlying store. Store *store() const { return document_->store(); } private: // Applies individual actions, which are assumed to be applicable. void Shift(); - void Stop(); void Mark(); void Evoke(int length, Handle type); void Refer(int length, int frame); void Connect(int source, Handle role, int target); void Assign(int frame, Handle role, Handle value); - void Embed(int frame, Handle role, Handle type); - void Elaborate(int frame, Handle role, Handle type); // Adds frame to attention buffer, making it the new center of attention. void Add(Handle frame, Span *span); @@ -145,9 +135,6 @@ class ParserState { // Current parse step. int step_; - // When we have performed the first STOP action, the parse is done. - bool done_; - // Attention buffer. This contains evoked frames in order of attention. The // last element is the center of attention. std::vector attention_; @@ -156,11 +143,6 @@ class ParserState { // each mention. std::vector marks_; - // (Source/Target frame handle, Frame type) for frames embedded or elaborated - // at the current position. This is cleared once the position advances. - std::vector> embed_; - std::vector> elaborate_; - // Maximum mark depth. static const int MAX_MARK_DEPTH = 5; }; diff --git a/sling/nlp/parser/parser-trainer.cc b/sling/nlp/parser/parser-trainer.cc index ecac96bd..647726f2 100644 --- a/sling/nlp/parser/parser-trainer.cc +++ b/sling/nlp/parser/parser-trainer.cc @@ -36,12 +36,15 @@ ParserTrainer::~ParserTrainer() { void ParserTrainer::Run(task::Task *task) { // Get training parameters. - task->Fetch("lstm_dim", &lstm_dim_); - task->Fetch("max_source", &max_source_); - task->Fetch("max_target", &max_target_); + task->Fetch("rnn_dim", &rnn_dim_); + task->Fetch("rnn_layers", &rnn_layers_); + task->Fetch("rnn_type", &rnn_type_); + task->Fetch("rnn_bidir", &rnn_bidir_); + task->Fetch("rnn_highways", &rnn_highways_); + task->Fetch("mark_depth", &mark_depth_); + task->Fetch("mark_dim", &mark_dim_); task->Fetch("frame_limit", &frame_limit_); - task->Fetch("attention_depth", &attention_depth_); task->Fetch("history_size", &history_size_); task->Fetch("out_roles_size", &out_roles_size_); task->Fetch("in_roles_size", &in_roles_size_); @@ -49,13 +52,16 @@ void ParserTrainer::Run(task::Task *task) { task->Fetch("unlabeled_roles_size", &unlabeled_roles_size_); task->Fetch("roles_dim", &roles_dim_); task->Fetch("activations_dim", &activations_dim_); - task->Fetch("link_dim_lstm", &link_dim_lstm_); - task->Fetch("link_dim_ff", &link_dim_ff_); - task->Fetch("mark_dim", &mark_dim_); + task->Fetch("link_dim_token", &link_dim_token_); + task->Fetch("link_dim_step", &link_dim_step_); + task->Fetch("seed", &seed_); task->Fetch("batch_size", &batch_size_); task->Fetch("learning_rate", &learning_rate_); task->Fetch("min_learning_rate", &min_learning_rate_); + task->Fetch("learning_rate_cliff", &learning_rate_cliff_); + task->Fetch("dropout", &dropout_); + task->Fetch("ff_l2reg", &ff_l2reg_); // Statistics. num_tokens_ = task->GetCounter("tokens"); @@ -99,6 +105,14 @@ void ParserTrainer::Run(task::Task *task) { spec_.quote_dim = task->Get("quote_dim", 8);; spec_.digit_dim = task->Get("digit_dim", 8);; + // Set up RNNs. 
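// A minimal sketch of the RNN stack setup that follows, using only fields that
// appear in this patch; the concrete values are illustrative, not defaults:
//
//   myelin::RNN::Spec spec;
//   spec.type = myelin::RNN::LSTM;          // selected via the rnn_type parameter
//   spec.dim = 128;                         // rnn_dim
//   spec.highways = false;                  // rnn_highways
//   spec.dropout = 0.0;                     // only applied during training
//   encoder_.AddLayers(/*layers=*/2, spec, /*bidir=*/true);
//
// With a bidirectional stack the per-token encoding seen by the decoder is
// presumably the concatenation of the forward and backward outputs of the top
// layer, which is why ParserTrainer::Build() later reads the token width from
// rnn.output->elements() instead of using rnn_dim directly.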
+ RNN::Spec rnn_spec; + rnn_spec.type = static_cast(rnn_type_); + rnn_spec.dim = rnn_dim_; + rnn_spec.highways = rnn_highways_; + rnn_spec.dropout = dropout_; + encoder_.AddLayers(rnn_layers_, rnn_spec, rnn_bidir_); + // Custom parser model initialization. This should set up the word and role // vocabularies as well as the delegate cascade. Setup(task); @@ -112,25 +126,33 @@ void ParserTrainer::Run(task::Task *task) { compiler_.Compile(&flow_, &model_); // Get decoder cells and tensors. - decoder_ = model_.GetCell("ff_trunk"); - activations_ = decoder_->GetParameter("ff_trunk/steps"); + decoder_ = model_.GetCell("decoder"); + encodings_ = decoder_->GetParameter("decoder/tokens"); + activations_ = decoder_->GetParameter("decoder/steps"); + activation_ = decoder_->GetParameter("decoder/activation"); + gdecoder_ = decoder_->Gradient(); primal_ = decoder_->Primal(); + dencodings_ = encodings_->Gradient(); dactivations_ = activations_->Gradient(); - dactivation_ = gdecoder_->GetParameter("ff_trunk/hidden")->Gradient(); - dlr_ = gdecoder_->GetParameter("gradients/ff_trunk/d_lr_lstm"); - drl_ = gdecoder_->GetParameter("gradients/ff_trunk/d_rl_lstm"); + dactivation_ = activation_->Gradient(); // Initialize model. - feature_model_.Init(model_.GetCell("ff_trunk"), - flow_.DataBlock("spec"), - &roles_, frame_limit_); - model_.InitLearnableWeights(seed_, 0.0, 0.01); + feature_model_.Init(decoder_, &roles_, frame_limit_); + model_.InitModelParameters(seed_); encoder_.Initialize(model_); optimizer_->Initialize(model_); for (auto *d : delegates_) d->Initialize(model_); commons_.Freeze(); + // Optionally load initial model parameters for restart. + if (task->Get("restart", false) && !model_filename_.empty()) { + LOG(INFO) << "Load model parameters from " << model_filename_; + Flow initial; + CHECK(initial.Load(model_filename_)); + model_.LoadParameters(initial); + } + // Train model. Train(task, &model_); @@ -162,6 +184,7 @@ void ParserTrainer::Worker(int index, Network *model) { std::vector decoders; myelin::Channel activations(activations_); myelin::Channel dactivations(dactivations_); + myelin::Channel dencodings(dencodings_); for (;;) { // Prepare next batch. for (auto *g : gradients) g->Clear(); @@ -199,7 +222,8 @@ void ParserTrainer::Worker(int index, Network *model) { } // Run document through encoder to produce contextual token encodings. - auto bilstm = encoder.Compute(*document, 0, document->length()); + myelin::Channel *encodings = + encoder.Compute(*document, 0, document->length()); // Run decoder and delegates on all steps in the transition sequence. int t = 0; @@ -211,7 +235,7 @@ void ParserTrainer::Worker(int index, Network *model) { // Attach instance to recurrent layers. decoder->Clear(); - features.Attach(bilstm, &activations, decoder); + features.Attach(encodings, &activations, decoder); // Extract features. features.Extract(decoder); @@ -239,18 +263,17 @@ void ParserTrainer::Worker(int index, Network *model) { } // Propagate gradients back through decoder. - auto grad = encoder.PrepareGradientChannels(document->length()); + dencodings.reset(document->length()); for (int s = steps - 1; s >= 0; --s) { gdecoder.Set(primal_, decoders[s]); + gdecoder.Set(dencodings_, &dencodings); gdecoder.Set(dactivations_, &dactivations); gdecoder.Set(dactivation_, &dactivations, s); - gdecoder.Set(dlr_, grad.lr); - gdecoder.Set(drl_, grad.rl); gdecoder.Compute(); } // Propagate gradients back through encoder. 
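// The backward pass above runs the decoder gradient cell once per transition
// step, in reverse step order, accumulating into two channels shared by all
// steps:
//
//   dencodings.reset(document->length());          // one row per token
//   for (int s = steps - 1; s >= 0; --s) {
//     gdecoder.Set(primal_, decoders[s]);          // forward instance for step s
//     gdecoder.Set(dencodings_, &dencodings);      // d(token encodings), summed
//     gdecoder.Set(dactivations_, &dactivations);  // d(step activations), summed
//     gdecoder.Set(dactivation_, &dactivations, s);// gradient of step s's output
//     gdecoder.Compute();
//   }
//
// The reverse order matters because later steps link to the activations of
// earlier steps through the history and attention features; their
// contributions must land in the dactivations channel before the earlier steps
// read their own output gradient from it. The accumulated dencodings channel
// is then handed to encoder.Backpropagate() below.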
- encoder.Backpropagate(); + encoder.Backpropagate(&dencodings); delete document; } @@ -280,22 +303,22 @@ void ParserTrainer::Parse(Document *document) const { for (SentenceIterator s(document); s.more(); s.next()) { // Run the lexical encoder for sentence. LexicalEncoderInstance encoder(encoder_); - auto bilstm = encoder.Compute(*document, s.begin(), s.end()); + myelin::Channel *encodings = encoder.Compute(*document, s.begin(), s.end()); // Initialize decoder. ParserState state(document, s.begin(), s.end()); ParserFeatureExtractor features(&feature_model_, &state); myelin::Instance decoder(decoder_); - myelin::Channel activations(feature_model_.hidden()); + myelin::Channel activations(feature_model_.activation()); // Run decoder to predict transitions. - for (;;) { + while (!state.done()) { // Allocate space for next step. activations.push(); // Attach instance to recurrent layers. decoder.Clear(); - features.Attach(bilstm, &activations, &decoder); + features.Attach(encodings, &activations, &decoder); // Extract features. features.Extract(&decoder); @@ -315,20 +338,13 @@ void ParserTrainer::Parse(Document *document) const { d = action.delegate; } - // Shift or stop if predicted action is invalid. + // Fall back to SHIFT if predicted action is not valid. if (!state.CanApply(action)) { - if (state.current() < state.end()) { - action.type = ParserAction::SHIFT; - } else { - action.type = ParserAction::STOP; - } + action.type = ParserAction::SHIFT; } // Apply action to parser state. state.Apply(action); - - // Check if we are done. - if (action.type == ParserAction::STOP) break; } } @@ -346,11 +362,10 @@ bool ParserTrainer::Evaluate(int64 epoch, Network *model) { loss_count_ = 0; // Decay learning rate if loss increases. - if (prev_loss_ != 0.0 && - prev_loss_ < loss && - learning_rate_ > min_learning_rate_) { - learning_rate_ = optimizer_->DecayLearningRate(); - } + bool decay = prev_loss_ != 0.0 && prev_loss_ < loss; + if (learning_rate_cliff_ != 0 && epoch >= learning_rate_cliff_) decay = true; + if (learning_rate_ <= min_learning_rate_) decay = false; + if (decay) learning_rate_ = optimizer_->DecayLearningRate(); prev_loss_ = loss; LOG(INFO) << "epoch=" << epoch @@ -364,8 +379,10 @@ bool ParserTrainer::Evaluate(int64 epoch, Network *model) { FrameEvaluation::Evaluate(&corpus, &eval); LOG(INFO) << "SPAN: " << eval.mention.Summary(); LOG(INFO) << "FRAME: " << eval.frame.Summary(); - LOG(INFO) << "TYPE: " << eval.type.Summary(); + LOG(INFO) << "PAIR: " << eval.pair.Summary(); + LOG(INFO) << "EDGE: " << eval.edge.Summary(); LOG(INFO) << "ROLE: " << eval.role.Summary(); + LOG(INFO) << "TYPE: " << eval.type.Summary(); LOG(INFO) << "LABEL: " << eval.label.Summary(); LOG(INFO) << "SLOT: " << eval.slot.Summary(); LOG(INFO) << "TOTAL: " << eval.combined.Summary(); @@ -382,105 +399,98 @@ void ParserTrainer::Checkpoint(int64 epoch, Network *model) { void ParserTrainer::Build(Flow *flow, bool learn) { // Build document input encoder. - BiLSTM::Outputs lstm; + RNN::Variables rnn; if (learn) { Vocabulary::HashMapIterator vocab(words_); - lstm = encoder_.Build(flow, spec_, &vocab, lstm_dim_, true); + rnn = encoder_.Build(flow, spec_, &vocab, true); } else { - lstm = encoder_.Build(flow, spec_, nullptr, lstm_dim_, false); + rnn = encoder_.Build(flow, spec_, nullptr, false); } + int token_dim = rnn.output->elements(); // Build parser decoder. 
- FlowBuilder f(flow, "ff_trunk"); + FlowBuilder f(flow, "decoder"); std::vector features; - Flow::Blob *spec = flow->AddBlob("spec", ""); // Add inputs for recurrent channels. - auto *lr = f.Placeholder("link/lr_lstm", DT_FLOAT, {1, lstm_dim_}, true); - auto *rl = f.Placeholder("link/rl_lstm", DT_FLOAT, {1, lstm_dim_}, true); + auto *tokens = f.Placeholder("tokens", DT_FLOAT, {1, token_dim}, true); auto *steps = f.Placeholder("steps", DT_FLOAT, {1, activations_dim_}, true); // Role features. if (in_roles_size_ > 0) { - features.push_back(f.Feature("in-roles", roles_.size() * frame_limit_, + features.push_back(f.Feature("in_roles", roles_.size() * frame_limit_, in_roles_size_, roles_dim_)); } if (out_roles_size_ > 0) { - features.push_back(f.Feature("out-roles", roles_.size() * frame_limit_, + features.push_back(f.Feature("out_roles", roles_.size() * frame_limit_, out_roles_size_, roles_dim_)); } if (labeled_roles_size_ > 0) { - features.push_back(f.Feature("labeled-roles", + features.push_back(f.Feature("labeled_roles", roles_.size() * frame_limit_ * frame_limit_, labeled_roles_size_, roles_dim_)); } if (unlabeled_roles_size_ > 0) { - features.push_back(f.Feature("unlabeled-roles", + features.push_back(f.Feature("unlabeled_roles", frame_limit_ * frame_limit_, unlabeled_roles_size_, roles_dim_)); } // Link features. - features.push_back(LinkedFeature(&f, "frame-creation-steps", - steps, frame_limit_, link_dim_ff_)); - features.push_back(LinkedFeature(&f, "frame-focus-steps", - steps, frame_limit_, link_dim_ff_)); + features.push_back(LinkedFeature(&f, "token", tokens, 1, link_dim_token_)); + features.push_back(LinkedFeature(&f, "attention_tokens", + tokens, frame_limit_, link_dim_token_)); + features.push_back(LinkedFeature(&f, "attention_steps", + steps, frame_limit_, link_dim_step_)); features.push_back(LinkedFeature(&f, "history", - steps, history_size_, link_dim_ff_)); - features.push_back(LinkedFeature(&f, "frame-end-lr", - lr, frame_limit_, link_dim_lstm_)); - features.push_back(LinkedFeature(&f, "frame-end-rl", - rl, frame_limit_, link_dim_lstm_)); - features.push_back(LinkedFeature(&f, "lr", lr, 1, link_dim_lstm_)); - features.push_back(LinkedFeature(&f, "rl", rl, 1, link_dim_lstm_)); + steps, history_size_, link_dim_step_)); // Mark features. - features.push_back(f.Feature("mark-distance", - mark_distance_bins_.size() + 1, - mark_depth_, mark_dim_)); - features.push_back(LinkedFeature(&f, "mark-lr", - lr, mark_depth_, link_dim_lstm_)); - features.push_back(LinkedFeature(&f, "mark-rl", - rl, mark_depth_, link_dim_lstm_)); - features.push_back(LinkedFeature(&f, "mark-step", - steps, mark_depth_, link_dim_ff_)); - string bins; - for (int d : mark_distance_bins_) { - if (!bins.empty()) bins.push_back(' '); - bins.append(std::to_string(d)); + features.push_back(LinkedFeature(&f, "mark_tokens", + tokens, mark_depth_, link_dim_token_)); + features.push_back(LinkedFeature(&f, "mark_steps", + steps, mark_depth_, link_dim_step_)); + + // Pad feature vector. + const static int alignment = 16; + int n = 0; + for (auto *f : features) n += f->elements(); + if (n % alignment != 0) { + int padding = alignment - n % alignment; + auto *zeroes = f.Const(nullptr, DT_FLOAT, {1, padding}); + features.push_back(zeroes); } - spec->SetAttr("mark_distance_bins", bins); - spec->SetAttr("frame_limit", frame_limit_); // Concatenate mapped feature inputs. auto *fv = f.Concat(features); int fvsize = fv->dim(1); // Feed-forward layer. 
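// The padding step above rounds the concatenated feature vector up to a
// multiple of 16 floats, presumably so the matrix multiply below always sees
// an input width that fills whole SIMD vectors. As an illustrative example, if
// the embedded features add up to n = 370 elements, a constant block of 14
// zeros is appended and fvsize becomes 384. The hidden layer built next then
// computes
//
//   activation = Relu(fv * W0 + b0)   // [1, fvsize] x [fvsize, activations_dim_]
//
// and marks "activation" as a referenced input/output so the per-step
// activations channel and the delegate cells can link to it.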
- auto *W = f.Random(f.Parameter("W0", DT_FLOAT, {fvsize, activations_dim_})); - auto *b = f.Random(f.Parameter("b0", DT_FLOAT, {1, activations_dim_})); - auto *activations = f.Name(f.Relu(f.Add(f.MatMul(fv, W), b)), "hidden"); - activations->set_in()->set_out()->set_ref(); + auto *W = f.Parameter("W0", DT_FLOAT, {fvsize, activations_dim_}); + auto *b = f.Parameter("b0", DT_FLOAT, {1, activations_dim_}); + f.RandomNormal(W); + if (ff_l2reg_ != 0.0) W->SetAttr("l2reg", ff_l2reg_); + auto *activation = f.Name(f.Relu(f.Add(f.MatMul(fv, W), b)), "activation"); + activation->set_in()->set_out()->set_ref(); // Build function decoder gradient. - Flow::Variable *dactivations = nullptr; + Flow::Variable *dactivation = nullptr; if (learn) { Gradient(flow, f.func()); - dactivations = flow->GradientVar(activations); + dactivation = flow->GradientVar(activation); } // Build flows for delegates. for (DelegateLearner *delegate : delegates_) { - delegate->Build(flow, activations, dactivations, learn); + delegate->Build(flow, activation, dactivation, learn); } // Link recurrences. - flow->Connect({lstm.lr, lr}); - flow->Connect({lstm.rl, rl}); - flow->Connect({steps, activations}); + flow->Connect({tokens, rnn.output}); + flow->Connect({steps, activation}); if (learn) { auto *dsteps = flow->GradientVar(steps); - flow->Connect({dsteps, dactivations}); + flow->Connect({dsteps, dactivation}); } } @@ -514,36 +524,55 @@ void ParserTrainer::Save(const string &filename) { Build(&flow, false); // Copy weights from trained model. - model_.SaveLearnedWeights(&flow); + model_.SaveParameters(&flow); // Save lexicon. encoder_.SaveLexicon(&flow); - // Save extra model data in store. + // Make parser specification frame. Store store(&commons_); - SaveModel(&flow, &store); + Builder spec(&store); + + // Save encoder spec. + Builder encoder_spec(&store); + encoder_spec.Add("type", "lexrnn"); + encoder_spec.Add("rnn", static_cast(rnn_type_)); + encoder_spec.Add("dim", rnn_dim_); + encoder_spec.Add("layers", rnn_layers_); + encoder_spec.Add("bidir", rnn_bidir_); + encoder_spec.Add("highways", rnn_highways_); + spec.Set("encoder", encoder_spec.Create()); + + // Save decoder spec. + Builder decoder_spec(&store); + decoder_spec.Add("type", "transition"); + decoder_spec.Set("frame_limit", frame_limit_); + decoder_spec.Set("sentence_reset", sentence_reset_); + + Handles role_list(&store); + roles_.GetList(&role_list); + decoder_spec.Set("roles", Array(&store, role_list)); - // Save delegates. - Builder cascade(&store); - cascade.AddId("/cascade"); Array delegates(&store, delegates_.size()); for (int i = 0; i < delegates_.size(); ++i) { - Builder data(&store); - delegates_[i]->Save(&flow, &data); - delegates.set(i, data.Create().handle()); + Builder delegate_spec(&store); + delegates_[i]->Save(&flow, &delegate_spec); + delegates.set(i, delegate_spec.Create().handle()); } - cascade.Add("delegates", delegates); - cascade.Create(); + decoder_spec.Set("delegates", delegates); + + spec.Set("decoder", decoder_spec.Create()); - // Save store in flow. + // Save parser spec in flow. StringEncoder encoder(&store); - encoder.EncodeAll(); + encoder.Encode(spec.Create()); - Flow::Blob *blob = flow.AddBlob("commons", "frames"); + Flow::Blob *blob = flow.AddBlob("parser", "frame"); blob->data = flow.AllocateMemory(encoder.buffer()); blob->size = encoder.buffer().size(); // Save model to file. 
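// The "parser" blob written above replaces the old "/cascade" spec and action
// table stored in the commons blob: it holds one encoded frame describing the
// whole model configuration. Based on the keys written in Save(), its shape is
// roughly the following (pseudo notation, values hypothetical):
//
//   {encoder: {type: "lexrnn" rnn: <type> dim: <dim> layers: <n>
//              bidir: true highways: false}
//    decoder: {type: "transition" frame_limit: 5 sentence_reset: true
//              roles: [<role handles>]
//              delegates: [{type: "..."} {type: "multiclass" cell: "..."}]}}
//
// Parser::Load() decodes this frame, rebuilds the RNN stack from the encoder
// part, and instantiates one registered Delegate per entry in the delegates
// array of the decoder part.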
+ DCHECK(flow.IsConsistent()); flow.Save(filename); } diff --git a/sling/nlp/parser/parser-trainer.h b/sling/nlp/parser/parser-trainer.h index 1b66785a..7532cb68 100644 --- a/sling/nlp/parser/parser-trainer.h +++ b/sling/nlp/parser/parser-trainer.h @@ -46,8 +46,8 @@ class DelegateLearner { // Build flow for delegate learner. virtual void Build(myelin::Flow *flow, - myelin::Flow::Variable *activations, - myelin::Flow::Variable *dactivations, + myelin::Flow::Variable *activation, + myelin::Flow::Variable *dactivation, bool learn) = 0; // Initialize network for delegate. @@ -57,7 +57,7 @@ class DelegateLearner { virtual DelegateLearnerInstance *CreateInstance() = 0; // Save model data to flow. - virtual void Save(myelin::Flow *flow, Builder *data) = 0; + virtual void Save(myelin::Flow *flow, Builder *spec) = 0; }; // Interface for delegate learner instance. @@ -72,8 +72,8 @@ class DelegateLearnerInstance { virtual void ClearGradients() = 0; // Compute loss and gradient for delegate with respect to golden action. - virtual float Compute(float *activations, - float *dactivations, + virtual float Compute(float *activation, + float *dactivation, const ParserAction &action) = 0; // Predict action for delegate. @@ -98,9 +98,6 @@ class ParserTrainer : public task::LearnerTask { virtual void GenerateTransitions(const Document &document, std::vector *transitions) = 0; - // Abstract method for saving extra data in final model. - virtual void SaveModel(myelin::Flow *flow, Store *store) = 0; - private: // Build flow graph for parser model. void Build(myelin::Flow *flow, bool learn); @@ -153,6 +150,9 @@ class ParserTrainer : public task::LearnerTask { // Role set. RoleSet roles_; + // Reset parser state between sentences in a document. + bool sentence_reset_ = false; + // Lexical feature specification for encoder. LexicalFeatures::Spec spec_; @@ -170,13 +170,15 @@ class ParserTrainer : public task::LearnerTask { // Decoder model. myelin::Cell *decoder_ = nullptr; + myelin::Tensor *encodings_ = nullptr; myelin::Tensor *activations_ = nullptr; + myelin::Tensor *activation_ = nullptr; + myelin::Cell *gdecoder_ = nullptr; - myelin::Tensor *dactivations_ = nullptr; myelin::Tensor *primal_ = nullptr; + myelin::Tensor *dencodings_ = nullptr; + myelin::Tensor *dactivations_ = nullptr; myelin::Tensor *dactivation_ = nullptr; - myelin::Tensor *dlr_ = nullptr; - myelin::Tensor *drl_ = nullptr; // Delegates. std::vector delegates_; @@ -186,12 +188,13 @@ class ParserTrainer : public task::LearnerTask { Mutex update_mu_; // Model hyperparameters. - int lstm_dim_ = 256; - int max_source_ = 5; - int max_target_ = 10; + int rnn_type_ = myelin::RNN::LSTM; + int rnn_dim_ = 256; + int rnn_layers_ = 1; + bool rnn_bidir_ = true; + bool rnn_highways_ = false; int mark_depth_ = 1; int frame_limit_ = 5; - int attention_depth_ = 5; int history_size_ = 5; int out_roles_size_ = 32; int in_roles_size_ = 32; @@ -199,14 +202,16 @@ class ParserTrainer : public task::LearnerTask { int unlabeled_roles_size_ = 32; int roles_dim_ = 16; int activations_dim_ = 128; - int link_dim_lstm_ = 32; - int link_dim_ff_ = 64; + int link_dim_token_ = 32; + int link_dim_step_ = 64; int mark_dim_ = 32; - std::vector mark_distance_bins_{0, 1, 2, 3, 6, 10, 15, 20}; int seed_ = 0; int batch_size_ = 32; + int learning_rate_cliff_ = 0; float learning_rate_ = 1.0; float min_learning_rate_ = 0.001; + float dropout_ = 0.0; + float ff_l2reg_ = 0.0; // Evaluation statistics. 
float prev_loss_ = 0.0; diff --git a/sling/nlp/parser/parser.cc b/sling/nlp/parser/parser.cc index 358f735e..8be82b1c 100644 --- a/sling/nlp/parser/parser.cc +++ b/sling/nlp/parser/parser.cc @@ -12,117 +12,139 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - #include "sling/nlp/parser/parser.h" #include "sling/frame/serialization.h" -#include "sling/myelin/profile.h" -#include "sling/myelin/kernel/dragnn.h" -#include "sling/nlp/document/document.h" -#include "sling/nlp/document/features.h" -#include "sling/nlp/document/lexicon.h" -#include "sling/nlp/parser/action-table.h" -using namespace std::placeholders; +REGISTER_COMPONENT_REGISTRY("parser delegate", sling::nlp::Delegate); namespace sling { namespace nlp { +Parser::~Parser() { + for (auto *d : delegates_) delete d; +} + void Parser::Load(Store *store, const string &model) { - // Load and analyze parser flow file. + // Load and compile parser flow. myelin::Flow flow; CHECK(flow.Load(model)); - - // FIXME(ringgaard): Patch feature cell output. - flow.Var("features/feature_vector")->set_in(); - - // Register DRAGNN kernel to support legacy parser models. - RegisterDragnnLibrary(compiler_.library()); - - // Compile parser flow. compiler_.Compile(&flow, &network_); - // Initialize lexical encoder. - encoder_.Initialize(network_); - encoder_.LoadLexicon(&flow); - // Load commons store from parser model. myelin::Flow::Blob *commons = flow.DataBlock("commons"); - CHECK(commons != nullptr); - StringDecoder decoder(store, commons->data, commons->size); - decoder.DecodeAll(); + if (commons != nullptr) { + StringDecoder decoder(store, commons->data, commons->size); + decoder.DecodeAll(); + } - // Read the cascade specification and implementation from the flow. - Frame cascade_spec(store, "/cascade"); - CHECK(cascade_spec.valid()); - cascade_.Initialize(network_, cascade_spec); + // Get parser specification. + myelin::Flow::Blob *spec_data = flow.DataBlock("parser"); + CHECK(spec_data != nullptr) << "No parser specification in model: " << model; + StringDecoder spec_decoder(store, spec_data->data, spec_data->size); + Frame spec = spec_decoder.Decode().AsFrame(); + CHECK(spec.valid()); + + // Initialize encoder. + Frame encoder_spec = spec.GetFrame("encoder"); + CHECK(encoder_spec.valid()); + CHECK_EQ(encoder_spec.GetText("type"), "lexrnn"); + myelin::RNN::Spec rnn_spec; + rnn_spec.type = static_cast(encoder_spec.GetInt("rnn")); + rnn_spec.dim = encoder_spec.GetInt("dim"); + rnn_spec.highways = encoder_spec.GetBool("highways"); + int rnn_layers = encoder_spec.GetInt("layers"); + bool rnn_bidir = encoder_spec.GetBool("bidir"); + + encoder_.AddLayers(rnn_layers, rnn_spec, rnn_bidir); + encoder_.Initialize(network_); + encoder_.LoadLexicon(&flow); - // Initialize action table. - store_ = store; - ActionTable actions; - actions.Init(store); - roles_.Init(actions.list()); + // Initialize decoder. + Frame decoder_spec = spec.GetFrame("decoder"); + CHECK(decoder_spec.valid()); + CHECK_EQ(decoder_spec.GetText("type"), "transition"); + int frame_limit = decoder_spec.GetInt("frame_limit"); + + // Initialize roles. + Array roles = decoder_spec.Get("roles").AsArray(); + if (roles.valid()) { + for (int i = 0; i < roles.length(); ++i) { + roles_.Add(roles.get(i)); + } + } + + // Initialize decoder cascade. 
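// Each entry in the "delegates" array below is a frame whose "type" field
// selects an implementation through the component registry; the rest of the
// frame is interpreted by that delegate. For the multiclass delegate added in
// this patch, the spec names the Myelin cell and carries the serialized action
// table. A sketch of loading a single entry:
//
//   Frame delegate_spec(store, delegates.get(i));
//   Delegate *delegate = Delegate::Create(delegate_spec.GetString("type"));
//   delegate->Initialize(network_, delegate_spec);  // e.g. reads "cell" + actions
//   delegates_.push_back(delegate);
//
// REGISTER_DELEGATE("multiclass", MultiClassDelegate) is what makes the
// "multiclass" type name resolvable via Delegate::Create().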
+ Array delegates = decoder_spec.Get("delegates").AsArray(); + CHECK(delegates.valid()); + for (int i = 0; i < delegates.length(); ++i) { + Frame delegate_spec(store, delegates.get(i)); + string type = delegate_spec.GetString("type"); + Delegate *delegate = Delegate::Create(type); + delegate->Initialize(network_, delegate_spec); + delegates_.push_back(delegate); + } // Initialize decoder feature model. - myelin::Flow::Blob *spec = flow.DataBlock("spec"); - decoder_ = network_.GetCell("ff_trunk"); - feature_model_.Init(decoder_, spec, &roles_, actions.frame_limit()); + decoder_ = network_.GetCell("decoder"); + feature_model_.Init(decoder_, &roles_, frame_limit); } void Parser::Parse(Document *document) const { + // Create delegates. + std::vector delegates; + for (auto *d : delegates_) delegates.push_back(d->CreateInstance()); + // Parse each sentence of the document. + LexicalEncoderInstance encoder(encoder_); for (SentenceIterator s(document); s.more(); s.next()) { - // Set up trace if feature tracing is enabled. - Trace *trace = trace_ ? new Trace(s.begin(), s.end()) : nullptr; - // Run the lexical encoder for sentence. - LexicalEncoderInstance encoder(encoder_); - if (trace) { - encoder.set_trace(std::bind(&Trace::AddLSTM, trace, _1, _2, _3)); - } - auto bilstm = encoder.Compute(*document, s.begin(), s.end()); + myelin::Channel *encodings = encoder.Compute(*document, s.begin(), s.end()); // Initialize decoder. ParserState state(document, s.begin(), s.end()); ParserFeatureExtractor features(&feature_model_, &state); myelin::Instance decoder(decoder_); - myelin::Channel activations(feature_model_.hidden()); - CascadeInstance cascade(&cascade_); + myelin::Channel activations(feature_model_.activation()); // Run decoder to predict transitions. - for (;;) { + while (!state.done()) { // Allocate space for next step. activations.push(); // Attach instance to recurrent layers. decoder.Clear(); - features.Attach(bilstm, &activations, &decoder); + features.Attach(encodings, &activations, &decoder); // Extract features. features.Extract(&decoder); - if (trace) features.TraceFeatures(&decoder, trace); // Compute decoder activations. decoder.Compute(); // Run the cascade. - ParserAction action; - cascade.Compute(&activations, &state, &action, trace); + ParserAction action(ParserAction::CASCADE, 0); + int step = state.step(); + float *activation = reinterpret_cast(activations.at(step)); + int d = 0; + for (;;) { + delegates[d]->Predict(activation, &action); + if (action.type != ParserAction::CASCADE) break; + CHECK_GT(action.delegate, d); + d = action.delegate; + } + + // Fall back to SHIFT if predicted action is not valid. + if (!state.CanApply(action)) { + action.type = ParserAction::SHIFT; + } // Apply action to parser state. state.Apply(action); - - // Check if we are done. - if (action.type == ParserAction::STOP) break; - } - - // Write feature trace to document. 
- if (trace) { - trace->Write(document); - delete trace; } } + + for (auto *d : delegates) delete d; } } // namespace nlp diff --git a/sling/nlp/parser/parser.h b/sling/nlp/parser/parser.h index beb51af6..6e946249 100644 --- a/sling/nlp/parser/parser.h +++ b/sling/nlp/parser/parser.h @@ -15,9 +15,10 @@ #ifndef SLING_NLP_PARSER_PARSER_H_ #define SLING_NLP_PARSER_PARSER_H_ -#include +#include #include "sling/base/logging.h" +#include "sling/base/registry.h" #include "sling/base/types.h" #include "sling/frame/store.h" #include "sling/myelin/compiler.h" @@ -25,18 +26,45 @@ #include "sling/myelin/flow.h" #include "sling/nlp/document/document.h" #include "sling/nlp/document/lexical-encoder.h" -#include "sling/nlp/parser/cascade.h" #include "sling/nlp/parser/parser-features.h" #include "sling/nlp/parser/parser-state.h" #include "sling/nlp/parser/roles.h" -#include "sling/nlp/parser/trace.h" namespace sling { namespace nlp { -// Frame semantics parser model. +class DelegateInstance; + +// Interface for delegate component at prediction time. +class Delegate : public Component { + public: + virtual ~Delegate() = default; + + // Initialize delegate from specification. + virtual void Initialize(const myelin::Network &network, + const Frame &spec) = 0; + + // Create new delegate instance for action prediction. + virtual DelegateInstance *CreateInstance() = 0; +}; + +#define REGISTER_DELEGATE(type, component) \ + REGISTER_COMPONENT_TYPE(sling::nlp::Delegate, type, component) + +// Interface for delegate instance at prediction time. +class DelegateInstance { + public: + virtual ~DelegateInstance() = default; + + // Predict action for delegate. + virtual void Predict(float *activations, ParserAction *action) = 0; +}; + +// Frame semantics parser. class Parser { public: + ~Parser(); + // Load and initialize parser model. void Load(Store *store, const string &filename); @@ -46,15 +74,6 @@ class Parser { // Neural network for parser. const myelin::Network &network() const { return network_; } - // Return the lexical encoder. - const LexicalEncoder &encoder() const { return encoder_; } - - // Returns whether tracing is enabled. - bool trace() const { return trace_; } - - // Enable/disable tracing. - void set_trace(bool trace) { trace_ = trace; } - private: // JIT compiler. myelin::Compiler compiler_; @@ -71,17 +90,11 @@ class Parser { // Parser feature model for feature extraction in the decoder. ParserFeatureModel feature_model_; - // Cascade. - Cascade cascade_; - - // Global store for parser. - Store *store_ = nullptr; + // Cascade with parser action prediction delegates. + std::vector delegates_; // Set of roles considered. RoleSet roles_; - - // Whether tracing is enabled. 
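// With tracing and the explicit cascade object gone, runtime use of the parser
// reduces to Load() and Parse(), as in tools/parse.cc. A minimal sketch; the
// model path is a placeholder and the document setup is elided:
//
//   sling::Store commons;
//   sling::nlp::Parser parser;
//   parser.Load(&commons, "path/to/caspar.flow");
//   commons.Freeze();
//
//   sling::Store store(&commons);
//   sling::nlp::Document document(&store);
//   // ... tokenize text into the document ...
//   parser.Parse(&document);
//
// Any delegate implementation referenced by the model spec must be linked into
// the binary so its REGISTER_DELEGATE registration runs, which is presumably
// why tools/BUILD now adds the //sling/nlp/parser:multiclass-delegate
// dependency.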
- bool trace_; }; } // namespace nlp diff --git a/sling/nlp/parser/roles.cc b/sling/nlp/parser/roles.cc index ef7cde7b..7fbd50f4 100644 --- a/sling/nlp/parser/roles.cc +++ b/sling/nlp/parser/roles.cc @@ -17,12 +17,21 @@ namespace sling { namespace nlp { -void RoleSet::Init(const std::vector &actions) { - for (const ParserAction &action : actions) { - if (!action.role.IsNil() && roles_.find(action.role) == roles_.end()) { - int index = roles_.size(); - roles_[action.role] = index; - } +void RoleSet::Add(Handle role) { + if (!role.IsNil() && roles_.find(role) == roles_.end()) { + int index = roles_.size(); + roles_[role] = index; + } +} + +void RoleSet::Add(const std::vector &actions) { + for (const ParserAction &action : actions) Add(action.role); +} + +void RoleSet::GetList(std::vector *list) const { + list->resize(roles_.size()); + for (auto &it : roles_) { + (*list)[it.second] = it.first; } } diff --git a/sling/nlp/parser/roles.h b/sling/nlp/parser/roles.h index c63ab9f9..3dc4e829 100644 --- a/sling/nlp/parser/roles.h +++ b/sling/nlp/parser/roles.h @@ -28,8 +28,11 @@ namespace nlp { // A mapping of roles to role ids extracted from the action set. class RoleSet { public: - // Initialize role mapping for action set. - void Init(const std::vector &actions); + // Add role to role set. + void Add(Handle role); + + // Add roles from action list. + void Add(const std::vector &actions); // Look up role id for role. Return -1 if role is unknown. int Lookup(Handle role) const { @@ -41,6 +44,9 @@ class RoleSet { // Return the number of roles in the role set. int size() const { return roles_.size(); } + // Get list of roles. + void GetList(std::vector *list) const; + private: // Mapping from role handle to role id. HandleMap roles_; diff --git a/sling/nlp/parser/tools/BUILD b/sling/nlp/parser/tools/BUILD index 82bed00e..1753cb9a 100644 --- a/sling/nlp/parser/tools/BUILD +++ b/sling/nlp/parser/tools/BUILD @@ -15,6 +15,7 @@ cc_binary( "//sling/nlp/document:document-tokenizer", "//sling/nlp/document:lex", "//sling/nlp/parser", + "//sling/nlp/parser:multiclass-delegate", "//sling/nlp/parser:frame-evaluation", "//sling/string:printf", ], diff --git a/sling/nlp/parser/tools/parse.cc b/sling/nlp/parser/tools/parse.cc index b20c2780..4e73ac17 100644 --- a/sling/nlp/parser/tools/parse.cc +++ b/sling/nlp/parser/tools/parse.cc @@ -54,7 +54,6 @@ DEFINE_string(output, "", "Output filename"); DEFINE_int32(indent, 2, "Indentation for SLING output"); DEFINE_string(corpus, "", "Input corpus"); DEFINE_bool(parse, false, "Parse input corpus"); -DEFINE_bool(trace, false, "Trace or not"); DEFINE_bool(benchmark, false, "Benchmark parser"); DEFINE_bool(lex, false, "Output documents in LEX format"); DEFINE_bool(evaluate, false, "Evaluate parser"); @@ -134,7 +133,6 @@ int main(int argc, char *argv[]) { Store commons; Parser parser; parser.Load(&commons, FLAGS_parser); - parser.set_trace(FLAGS_trace); commons.Freeze(); clock.stop(); LOG(INFO) << clock.ms() << " ms loading parser"; @@ -167,7 +165,6 @@ int main(int argc, char *argv[]) { if (FLAGS_parse) { CHECK(!FLAGS_corpus.empty()); LOG(INFO) << "Parse " << FLAGS_corpus; - if (FLAGS_trace) LOG(INFO) << "Tracing on"; DocumentCorpus corpus(&commons, FLAGS_corpus); int num_documents = 0; RecordWriter *writer = nullptr; diff --git a/sling/nlp/parser/tools/train_caspar.py b/sling/nlp/parser/tools/train_caspar.py index 54fadd88..d51d0049 100644 --- a/sling/nlp/parser/tools/train_caspar.py +++ b/sling/nlp/parser/tools/train_caspar.py @@ -2,8 +2,20 @@ import sling.flags as flags 
 import sling.task.workflow as workflow
 
-# Start up workflow system.
+flags.define("--accurate", default=False,action='store_true')
+
 flags.parse()
+
+if flags.arg.accurate:
+  modelfn = "local/data/e/caspar/caspar-accurate.flow"
+  rnn_layers = 3
+  rnn_dim = 192
+else:
+  modelfn = "local/data/e/caspar/caspar.flow"
+  rnn_layers = 1
+  rnn_dim = 128
+
+# Start up workflow system.
 workflow.startup()
 
 # Create workflow.
@@ -25,23 +37,28 @@
   format="embeddings"
 )
 
-parser_model = wf.resource(
-  "local/data/e/caspar/caspar.flow",
-  format="flow"
-)
+parser_model = wf.resource(modelfn, format="flow")
 
 # Parser trainer task.
 trainer = wf.task("caspar-trainer")
 trainer.add_params({
+  "rnn_type": 1,
+  "rnn_dim": rnn_dim,
+  "rnn_highways": True,
+  "rnn_layers": rnn_layers,
+  "dropout": 0.2,
+  "ff_l2reg": 0.0001,
+
   "learning_rate": 1.0,
   "learning_rate_decay": 0.8,
   "clipping": 1,
   "optimizer": "sgd",
-  "epochs": 50000,
   "batch_size": 32,
   "rampup": 120,
-  "report_interval": 500
+  "report_interval": 1000,
+  "learning_rate_cliff": 40000,
+  "epochs": 50000,
 })
 
 trainer.attach_input("training_corpus", training_corpus)
diff --git a/sling/nlp/parser/trace.cc b/sling/nlp/parser/trace.cc
deleted file mode 100644
index e8b03fc3..00000000
--- a/sling/nlp/parser/trace.cc
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2017 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "sling/nlp/parser/trace.h"
-
-namespace sling {
-namespace nlp {
-
-void Trace::Step::Add(int *ptr, int num, const string &name) {
-  // Nothing to do if the feature is missing.
-  if (ptr == nullptr || num == 0) return;
-
-  auto &features = ff_features[name];
-  for (int i = 0; i < num; ++i) {
-    // Note: -2 signals end of feature indices.
-    if (ptr[i] != -2) features.push_back(ptr[i]);
-  }
-}
-
-void Trace::Action(const ParserAction &action) {
-  steps.back().actions.emplace_back(action, action);
-}
-
-void Trace::Fallback(const ParserAction &fallback) {
-  steps.back().actions.back().second = fallback;
-}
-
-void Trace::AddLSTM(int token, const string &name, int val) {
-  if (lstm_features.size() <= token) lstm_features.resize(token + 1);
-  size_t index = name.rfind('/');
-  if (index != string::npos) {
-    string shortname = name.substr(index + 1);
-    lstm_features[token][shortname].emplace_back(val);
-  } else {
-    lstm_features[token][name].emplace_back(val);
-  }
-}
-
-void Trace::Write(Document *document) const {
-  Store *store = document->store();
-  Builder builder(store);
-  builder.Add("begin", begin);
-  builder.Add("end", end);
-
-  // Write encoder features.
-  Array lstm_array(store, lstm_features.size());
-  for (int t = 0; t < lstm_features.size(); ++t) {
-    Builder lstm(store);
-    lstm.Add("/trace/token", document->token(t).word());
-    lstm.Add("/trace/index", t);
-    for (const auto &kv : lstm_features[t]) {
-      Array values(store, kv.second.size());
-      for (int v = 0; v < kv.second.size(); ++v) {
-        values.set(v, Handle::Integer(kv.second[v]));
-      }
-      lstm.Add("/trace/" + kv.first, values);
-    }
-    lstm_array.set(t, lstm.Create().handle());
-  }
-  builder.Add("/trace/lstm_features", lstm_array);
-
-  // Write steps.
-  Array steps_array(store, steps.size());
-  for (int i = 0; i < steps.size(); ++i) {
-    Builder step(store);
-    int current = steps[i].current;
-    step.Add("/trace/current", current);
-    step.Add("/trace/index", i);
-    string word = "";
-    if (current < document->num_tokens()) {
-      word = document->token(current).word();
-    }
-    step.Add("/trace/current_word", word);
-
-    // Write decoder features.
-    Array ff(store, steps[i].ff_features.size());
-    int ff_count = 0;
-    for (const auto &kv : steps[i].ff_features) {
-      Builder feature(store);
-      feature.Add("/trace/feature", kv.first);
-      Array values(store, kv.second.size());
-      for (int v = 0; v < kv.second.size(); ++v) {
-        values.set(v, Handle::Integer(kv.second[v]));
-      }
-      feature.Add("/trace/values", values);
-      ff.set(ff_count++, feature.Create().handle());
-    }
-    step.Add("/trace/ff_features", ff);
-
-    // Write (predicted, final) actions.
-    Array actions(store, steps[i].actions.size());
-    for (int a = 0; a < steps[i].actions.size(); ++a) {
-      Frame predicted = steps[i].actions[a].first.AsFrame(store, "/trace/");
-      Frame applied = steps[i].actions[a].second.AsFrame(store, "/trace/");
-      Builder action(store);
-      action.Add("/trace/predicted", predicted);
-      action.Add("/trace/final", applied);
-      actions.set(a, action.Create().handle());
-    }
-    step.Add("/trace/actions", actions);
-    steps_array.set(i, step.Create().handle());
-  }
-  builder.Add("/trace/steps", steps_array);
-
-  document->AddExtra(store->Lookup("trace"), builder.Create().handle());
-}
-
-}  // namespace nlp
-}  // namespace sling
diff --git a/sling/nlp/parser/trace.h b/sling/nlp/parser/trace.h
deleted file mode 100644
index 1a9d5ad7..00000000
--- a/sling/nlp/parser/trace.h
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright 2017 Google Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef SLING_NLP_PARSER_TRACE_H_
-#define SLING_NLP_PARSER_TRACE_H_
-
-#include
-#include
-#include
-
-#include "sling/frame/object.h"
-#include "sling/frame/store.h"
-#include "sling/nlp/document/document.h"
-#include "sling/nlp/parser/parser-action.h"
-
-namespace sling {
-namespace nlp {
-
-// Tracing information for a semantic parse.
-struct Trace {
-  // Single step of the decoder.
-  struct Step {
-    // Index of the current token,
-    int current;
-
-    // Decoder feature -> List of feature value(s).
-    std::unordered_map<string, std::vector<int>> ff_features;
-
-    // List of (predicted action, final action) in this step's cascade.
-    std::vector<std::pair<ParserAction, ParserAction>> actions;
-
-    // Adds feature values for feature 'name'. If 'ptr' is not nullptr,
-    // 'num' values are taken starting at 'ptr'.
-    void Add(int *ptr, int num, const string &name);
-  };
-
-  Trace(int begin, int end) : begin(begin), end(end) {}
-
-  // Beginning and ending tokens of the parser state.
-  int begin;
-  int end;
-
-  // Token -> (encoder feature name -> feature values).
-  std::vector<std::unordered_map<string, std::vector<int>>> lstm_features;
-
-  // List of steps.
-  std::vector<Step> steps;
-
-  // Adds a predicted (=final) action to the latest step.
-  void Action(const ParserAction &action);
-
-  // Sets the last final action to 'fallback' for the latest step.
-  void Fallback(const ParserAction &fallback);
-
-  // Adds encoder feature values for 'token'.
-  void AddLSTM(int token, const string &name, int val);
-
-  // Writes tracing information as a frame to 'document'.
-  void Write(Document *document) const;
-};
-
-}  // namespace nlp
-}  // namespace sling
-
-#endif  // SLING_NLP_PARSER_TRACE_H_
diff --git a/sling/nlp/parser/trainer/pytorch_modules.py b/sling/nlp/parser/trainer/pytorch_modules.py
index 06122834..e83e690b 100644
--- a/sling/nlp/parser/trainer/pytorch_modules.py
+++ b/sling/nlp/parser/trainer/pytorch_modules.py
@@ -691,7 +691,7 @@ def finish_concat_op(bldr, op):
     flow_ff.set_layer_data(0, self.ff_layer.weight.data.numpy(), \
                            self.ff_layer.bias.data.numpy())
 
-    ff_concat_op = ff.rawop(optype="ConcatV2", name="concat")
+    ff_concat_op = ff.rawop(optype="Concat", name="concat")
     ff_concat_op.add_output(ff_input)
 
     # Add link variable to the given connector.
diff --git a/sling/nlp/parser/transition-generator.cc b/sling/nlp/parser/transition-generator.cc
index aaf4d4fa..c3d5b3fb 100644
--- a/sling/nlp/parser/transition-generator.cc
+++ b/sling/nlp/parser/transition-generator.cc
@@ -148,14 +148,6 @@ ParserAction Translate(const Handles &attention, const Action &action) {
     case ParserAction::REFER:
      output.target = Index(attention, action.frame->handle);
      break;
-    case ParserAction::EMBED:
-      output.label = action.frame->type;
-      output.target = Index(attention, action.neighbor->handle);
-      break;
-    case ParserAction::ELABORATE:
-      output.label = action.frame->type;
-      output.source = Index(attention, action.neighbor->handle);
-      break;
     case ParserAction::CONNECT:
      output.source = Index(attention, action.frame->handle);
      output.target = Index(attention, action.neighbor->handle);
@@ -174,10 +166,10 @@ ParserAction Translate(const Handles &attention, const Action &action) {
 // Updates 'attention' as a result of execution 'action'.
 void Update(const Action &action, Handles *attention) {
   auto type = action.core.type;
-  if (type == ParserAction::EVOKE || type == ParserAction::EMBED ||
-      type == ParserAction::ELABORATE) {
+  if (type == ParserAction::EVOKE) {
     attention->emplace_back(action.frame->handle);
-  } else if (type == ParserAction::REFER || type == ParserAction::ASSIGN ||
+  } else if (type == ParserAction::REFER ||
+             type == ParserAction::ASSIGN ||
              type == ParserAction::CONNECT) {
     QCHECK_GT(attention->size(), 0);
     auto handle = action.frame->handle;
@@ -339,7 +331,6 @@ void CollectSpanActions(Store *store,
     actions->emplace_back(ParserAction::SHIFT, nullptr);
   }
 
-  actions->emplace_back(ParserAction::STOP, nullptr);
 }
 
 void OutputActions(Store *store,
@@ -355,8 +346,7 @@ void OutputActions(Store *store,
     actions->pop_front();
 
     ParserAction::Type type = output.type;
-    if (type == ParserAction::EVOKE || type == ParserAction::EMBED ||
-        type == ParserAction::ELABORATE) {
+    if (type == ParserAction::EVOKE) {
       action.frame->accounted = true;
 
       // CONNECT.
@@ -375,39 +365,6 @@ void OutputActions(Store *store,
         actions->push_front(connect);
       }
 
-      // EMBED.
-      for (auto *edge : action.frame->edges) {
-        if (edge->accounted || !edge->incoming) continue;
-
-        FrameGraph::Node *neighbor = frame_graph.node(edge->neighbor);
-        if (neighbor == nullptr || neighbor->accounted ||
-            neighbor->evoked) {
-          continue;
-        }
-        Action embed(ParserAction::EMBED, neighbor);
-        embed.core.role = edge->role;
-        embed.neighbor = action.frame;
-        edge->accounted = true;
-        edge->inverse->accounted = true;
-        actions->push_front(embed);
-      }
-
-      // ELABORATE.
-      for (auto *edge : action.frame->edges) {
-        if (edge->accounted || edge->incoming) continue;
-
-        FrameGraph::Node *neighbor = frame_graph.node(edge->neighbor);
-        if (neighbor == nullptr || neighbor->accounted || neighbor->evoked) {
-          continue;
-        }
-        Action elaborate(ParserAction::EMBED, neighbor);
-        elaborate.core.role = edge->role;
-        elaborate.neighbor = action.frame;
-        edge->accounted = true;
-        edge->inverse->accounted = true;
-        actions->push_front(elaborate);
-      }
-
       // ASSIGN.
       for (auto *edge : action.frame->edges) {
         if (edge->accounted || edge->incoming) continue;
diff --git a/sling/pyapi/BUILD b/sling/pyapi/BUILD
index b3e3fb7f..c2447666 100644
--- a/sling/pyapi/BUILD
+++ b/sling/pyapi/BUILD
@@ -54,6 +54,7 @@ cc_library(
     "//sling/nlp/kb:facts",
     "//sling/nlp/kb:phrase-table",
     "//sling/nlp/parser",
+    "//sling/nlp/parser:multiclass-delegate",
     "//sling/nlp/parser:frame-evaluation",
     "//sling/nlp/wiki:wikidata-converter",
     "//sling/task:dashboard",
diff --git a/sling/pyapi/pymyelin.cc b/sling/pyapi/pymyelin.cc
index 1c178c7e..9e2f7daa 100644
--- a/sling/pyapi/pymyelin.cc
+++ b/sling/pyapi/pymyelin.cc
@@ -249,6 +249,8 @@ bool PyCompiler::ImportFlow(PyObject *pyflow, Flow *flow, PyBuffers *buffers) {
     var->flags = PyIntAttr(pyvar, "flags");
     varmap[pyvar] = var;
 
+    if (!ImportAttributes(pyvar, var)) return false;
+
     PyObject *pydata = PyAttr(pyvar, "data");
     if (pydata != Py_None) {
       var->data = buffers->GetData(pydata, var->type, &var->size);
diff --git a/sling/task/learner.cc b/sling/task/learner.cc
index 74e178c2..c4c60c86 100644
--- a/sling/task/learner.cc
+++ b/sling/task/learner.cc
@@ -107,7 +107,7 @@ Optimizer *GetOptimizer(Task *task) {
     sgd->set_learning_rate(lr);
     sgd->set_decay(decay);
     sgd->set_clipping_threshold(clip);
-    sgd->set_lambda(task->Get("regularization", 0.0));
+    sgd->set_lambda(task->Get("l2reg", 0.0));
     return sgd;
   } else if (type == "momentum") {
     MomentumOptimizer *momentum = new MomentumOptimizer();