BertSentenceEmbeddings.scala
/*
* Copyright 2017-2022 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.johnsnowlabs.nlp.embeddings
import com.johnsnowlabs.ml.ai.Bert
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, SparkSession}
/** Sentence-level embeddings using BERT. BERT (Bidirectional Encoder Representations from
* Transformers) provides dense vector representations for natural language by using a deep,
* pre-trained neural network with the Transformer architecture.
*
* Pretrained models can be loaded with `pretrained` of the companion object:
* {{{
* val embeddings = BertSentenceEmbeddings.pretrained()
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
* }}}
* The default model is `"sent_small_bert_L2_768"`, if no name is provided.
*
* For available pretrained models please see the
* [[https://sparknlp.org/models?task=Embeddings Models Hub]].
*
* For extended examples of usage, see the
* [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20BERT%20Sentence.ipynb Examples]]
* and the
* [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddingsTestSpec.scala BertSentenceEmbeddingsTestSpec]].
*
* '''Sources''' :
*
* [[https://arxiv.org/abs/1810.04805 BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding]]
*
* [[https://github.com/google-research/bert]]
*
* ''' Paper abstract '''
*
* ''We introduce a new language representation model called BERT, which stands for Bidirectional
* Encoder Representations from Transformers. Unlike recent language representation models, BERT
* is designed to pre-train deep bidirectional representations from unlabeled text by jointly
* conditioning on both left and right context in all layers. As a result, the pre-trained BERT
* model can be fine-tuned with just one additional output layer to create state-of-the-art
* models for a wide range of tasks, such as question answering and language inference, without
* substantial task-specific architecture modifications. BERT is conceptually simple and
* empirically powerful. It obtains new state-of-the-art results on eleven natural language
* processing tasks, including pushing the GLUE score to 80.5% (7.7% point absolute improvement),
* MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1 question answering Test F1
* to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point absolute
* improvement).''
*
* ==Example==
* {{{
* import spark.implicits._
* import com.johnsnowlabs.nlp.base.DocumentAssembler
* import com.johnsnowlabs.nlp.annotator.SentenceDetector
* import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings
* import com.johnsnowlabs.nlp.EmbeddingsFinisher
* import org.apache.spark.ml.Pipeline
*
* val documentAssembler = new DocumentAssembler()
* .setInputCol("text")
* .setOutputCol("document")
*
* val sentence = new SentenceDetector()
* .setInputCols("document")
* .setOutputCol("sentence")
*
* val embeddings = BertSentenceEmbeddings.pretrained("sent_small_bert_L2_128")
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
*
* val embeddingsFinisher = new EmbeddingsFinisher()
* .setInputCols("sentence_bert_embeddings")
* .setOutputCols("finished_embeddings")
* .setOutputAsVector(true)
*
* val pipeline = new Pipeline().setStages(Array(
* documentAssembler,
* sentence,
* embeddings,
* embeddingsFinisher
* ))
*
* val data = Seq("John loves apples. Mary loves oranges. John loves Mary.").toDF("text")
* val result = pipeline.fit(data).transform(data)
*
* result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
* +--------------------------------------------------------------------------------+
* | result|
* +--------------------------------------------------------------------------------+
* |[-0.8951074481010437,0.13753940165042877,0.3108254075050354,-1.65693199634552...|
* |[-0.6180210709571838,-0.12179657071828842,-0.191165953874588,-1.4497021436691...|
* |[-0.822715163230896,0.7568016648292542,-0.1165061742067337,-1.59048593044281,...|
* +--------------------------------------------------------------------------------+
* }}}
*
* @see
* [[BertEmbeddings]] for token-level embeddings
* @see
* [[com.johnsnowlabs.nlp.annotators.classifier.dl.BertForSequenceClassification BertForSequenceClassification]]
* for embeddings with a sequence classification layer on top
* @see
* [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer
* based embeddings
* @param uid
* required uid for storing annotator to disk
* @groupname anno Annotator types
* @groupdesc anno
* Required input and expected output annotator types
* @groupname Ungrouped Members
* @groupname param Parameters
* @groupname setParam Parameter setters
* @groupname getParam Parameter getters
* @groupname Ungrouped Members
* @groupprio param 1
* @groupprio anno 2
* @groupprio Ungrouped 3
* @groupprio setParam 4
* @groupprio getParam 5
* @groupdesc param
* A list of (hyper-)parameter keys this annotator can take. Users can set and get the
* parameter values through setters and getters, respectively.
*/
class BertSentenceEmbeddings(override val uid: String)
extends AnnotatorModel[BertSentenceEmbeddings]
with HasBatchedAnnotate[BertSentenceEmbeddings]
with WriteTensorflowModel
with HasEmbeddingsProperties
with HasStorageRef
with HasCaseSensitiveProperties
with HasEngine
with HasProtectedParams {
def this() = this(Identifiable.randomUID("BERT_SENTENCE_EMBEDDINGS"))
/** Vocabulary used to encode the words to ids with WordPieceEncoder
*
* @group param
*/
val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected()
/** ConfigProto from tensorflow, serialized into byte array. Get with
* config_proto.SerializeToString()
*
* @group param
*/
val configProtoBytes = new IntArrayParam(
this,
"configProtoBytes",
"ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()")
/** Max sentence length to process (Default: `128`)
*
* @group param
*/
val maxSentenceLength =
new IntParam(this, "maxSentenceLength", "Max sentence length to process").setProtected()
/** Use Long type instead of Int type for inputs (Default: `false`)
*
* @group param
*/
val isLong = new BooleanParam(
parent = this,
name = "isLong",
"Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int.")
.setProtected()
/** set isLong
*
* @group setParam
*/
def setIsLong(value: Boolean): this.type = {
set(this.isLong, value)
}
/** get isLong
*
* @group getParam
*/
def getIsLong: Boolean = $(isLong)
/** @group setParam */
def sentenceStartTokenId: Int = {
$$(vocabulary)("[CLS]")
}
/** @group setParam */
def sentenceEndTokenId: Int = {
$$(vocabulary)("[SEP]")
}
/** Set the embeddings dimension for the BERT model. This can only be set the first time the
* model is saved; afterwards the dimension is not changeable, as it comes from the BERT config file.
*
* @group setParam
*/
override def setDimension(value: Int): this.type = {
set(this.dimension, value)
}
/** Whether to treat tokens as case sensitive (if `false`, tokens are lowercased)
*
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
set(this.caseSensitive, value)
}
/** Vocabulary used to encode the words to ids with WordPieceEncoder
*
* @group setParam
*/
def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value)
/** ConfigProto from tensorflow, serialized into byte array. Get with
* config_proto.SerializeToString()
*
* @group setParam
*/
def setConfigProtoBytes(bytes: Array[Int]): BertSentenceEmbeddings.this.type =
set(this.configProtoBytes, bytes)
/** Max sentence length to process (Default: `128`)
*
* @group setParam
*/
def setMaxSentenceLength(value: Int): this.type = {
require(
value <= 512,
"BERT models do not support sequences longer than 512 because of trainable positional embeddings")
set(maxSentenceLength, value)
}
/** ConfigProto from tensorflow, serialized into byte array. Get with
* config_proto.SerializeToString()
*
* @group getParam
*/
def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte))
/** Max sentence length to process (Default: `128`)
*
* @group getParam
*/
def getMaxSentenceLength: Int = $(maxSentenceLength)
setDefault(
dimension -> 768,
batchSize -> 8,
maxSentenceLength -> 128,
caseSensitive -> false,
isLong -> false)
/** It contains TF model signatures for the loaded saved model
*
* @group param
*/
val signatures =
new MapFeature[String, String](model = this, name = "signatures").setProtected()
/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
set(signatures, value)
this
}
/** @group getParam */
def getSignatures: Option[Map[String, String]] = get(this.signatures)
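// The BERT TensorFlow model wrapper, broadcast to executors once set via setModelIfNotSet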
private var _model: Option[Broadcast[Bert]] = None
/** @group getParam */
def getModelIfNotSet: Bert = _model.get.value
/** @group setParam */
def setModelIfNotSet(spark: SparkSession, tensorflow: TensorflowWrapper): this.type = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new Bert(
tensorflow,
sentenceStartTokenId,
sentenceEndTokenId,
configProtoBytes = getConfigProtoBytes,
signatures = getSignatures)))
}
this
}
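/** Tokenizes sentences with a BasicTokenizer and encodes the resulting tokens into WordPiece
* ids using the loaded vocabulary.
*/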
def tokenize(sentences: Seq[Sentence]): Seq[WordpieceTokenizedSentence] = {
val basicTokenizer = new BasicTokenizer($(caseSensitive))
val encoder = new WordpieceEncoder($$(vocabulary))
sentences.map { s =>
val tokens = basicTokenizer.tokenize(s)
val wordpieceTokens = tokens.flatMap(token => encoder.encode(token))
WordpieceTokenizedSentence(wordpieceTokens)
}
}
/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
*
* @param batchedAnnotations
* Annotations that correspond to inputAnnotationCols generated by previous annotators if any
* @return
* any number of annotations processed for every input annotation; not necessarily a
* one-to-one relationship
*/
override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
// Unpack annotations and zip each sentence to the index or the row it belongs to
val sentencesWithRow = batchedAnnotations.zipWithIndex
.flatMap { case (annotations, i) => SentenceSplit.unpack(annotations).map(x => (x, i)) }
// Tokenize sentences
val tokenizedSentences = tokenize(sentencesWithRow.map(_._1))
// Process all sentences
val allAnnotations = getModelIfNotSet.predictSequence(
tokenizedSentences,
sentencesWithRow.map(_._1),
$(batchSize),
$(maxSentenceLength),
getIsLong)
// Group resulting annotations by row. If there are no sentences in a given row, return an empty sequence
batchedAnnotations.indices.map(rowIndex => {
val rowAnnotations = allAnnotations
// zip each annotation with its corresponding row index
.zip(sentencesWithRow)
// select the sentences belonging to the current row
.filter(_._2._2 == rowIndex)
// leave the annotation only
.map(_._1)
if (rowAnnotations.nonEmpty)
rowAnnotations
else
Seq.empty[Annotation]
})
}
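/** Attaches sentence embeddings metadata (embedding dimension and storage reference) to the
* output column once annotation is done.
*/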
override protected def afterAnnotate(dataset: DataFrame): DataFrame = {
dataset.withColumn(
getOutputCol,
wrapSentenceEmbeddingsMetadata(
dataset.col(getOutputCol),
$(dimension),
Some($(storageRef))))
}
/** Annotator reference id. Used to identify elements in metadata or to refer to this annotator
* type
*/
override val inputAnnotatorTypes: Array[String] = Array(AnnotatorType.DOCUMENT)
override val outputAnnotatorType: AnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS
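/** Writes the TensorFlow model alongside the annotator's metadata when the model is saved to
* disk.
*/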
override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper,
"_bert_sentence",
BertSentenceEmbeddings.tfFile,
configProtoBytes = getConfigProtoBytes)
}
}
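/** Provides the `pretrained` overloads for downloading and loading a pretrained
* [[BertSentenceEmbeddings]] model.
*/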
trait ReadablePretrainedBertSentenceModel
extends ParamsAndFeaturesReadable[BertSentenceEmbeddings]
with HasPretrained[BertSentenceEmbeddings] {
override val defaultModelName: Some[String] = Some("sent_small_bert_L2_768")
/** Java-compliant overrides */
override def pretrained(): BertSentenceEmbeddings = super.pretrained()
override def pretrained(name: String): BertSentenceEmbeddings = super.pretrained(name)
override def pretrained(name: String, lang: String): BertSentenceEmbeddings =
super.pretrained(name, lang)
override def pretrained(name: String, lang: String, remoteLoc: String): BertSentenceEmbeddings =
super.pretrained(name, lang, remoteLoc)
}
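/** Provides the reader for the TensorFlow model of a saved [[BertSentenceEmbeddings]] as well
* as support for importing external models via `loadSavedModel`.
*/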
trait ReadBertSentenceDLModel extends ReadTensorflowModel {
this: ParamsAndFeaturesReadable[BertSentenceEmbeddings] =>
override val tfFile: String = "bert_sentence_tensorflow"
def readModel(instance: BertSentenceEmbeddings, path: String, spark: SparkSession): Unit = {
val tf = readTensorflowModel(path, spark, "_bert_sentence_tf", initAllTables = false)
instance.setModelIfNotSet(spark, tf)
}
addReader(readModel)
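/** Loads an externally exported model (e.g. a TensorFlow SavedModel), reading the vocabulary
* from `vocab.txt` and detecting the model engine before wrapping it into a
* [[BertSentenceEmbeddings]] instance.
*/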
def loadSavedModel(modelPath: String, spark: SparkSession): BertSentenceEmbeddings = {
val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)
val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap
/* Universal parameters for all engines */
val annotatorModel = new BertSentenceEmbeddings()
.setVocabulary(vocabs)
annotatorModel.set(annotatorModel.engine, detectedEngine)
detectedEngine match {
case ModelEngine.tensorflow =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)
val _signatures = signatures match {
case Some(s) => s
case None => throw new Exception("Cannot load signature definitions from model!")
}
/** The order matters here: setSignatures must be called before setModelIfNotSet, since
* setModelIfNotSet reads getSignatures internally.
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper)
case _ =>
throw new Exception(notSupportedEngineError)
}
annotatorModel
}
}
/** This is the companion object of [[BertSentenceEmbeddings]]. Please refer to that class for the
* documentation.
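*
* An exported TensorFlow SavedModel can also be imported with `loadSavedModel` (a minimal
* sketch; the path below is a placeholder for an actual exported model):
* {{{
* val embeddings = BertSentenceEmbeddings.loadSavedModel("/path/to/exported_bert_savedmodel", spark)
* .setInputCols("sentence")
* .setOutputCol("sentence_bert_embeddings")
* }}}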
*/
object BertSentenceEmbeddings
extends ReadablePretrainedBertSentenceModel
with ReadBertSentenceDLModel