Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrating ONNX runtime (ORT) in Spark NLP 5.0.0 🎉 #13857

Merged
merged 33 commits into from
Jul 1, 2023
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
406ef47
Add ONNX Runtime to the dependencies
maziyarpanahi Apr 13, 2023
1dcd672
Add both CPU and GPU coordinates for onnxruntime
maziyarpanahi Apr 13, 2023
61ecae1
Implement OnnxSerializeModel
maziyarpanahi Apr 13, 2023
c1ae390
Implement OnnxWrapper
maziyarpanahi Apr 13, 2023
1a02cc4
Update error message for loading external models
maziyarpanahi Apr 13, 2023
73ce499
Add support for ONNX to BertEmbeddings annotator
maziyarpanahi Apr 13, 2023
7dfb64b
Add support for ONNX to BERT backend
maziyarpanahi Apr 13, 2023
991aa80
Add support for ONNX to DeBERTa
maziyarpanahi Apr 13, 2023
1a7ba47
Implement ONNX in DeBERTa backend
maziyarpanahi Apr 13, 2023
28f3ec4
Adapt Bert For sentence embeddings with the new backend
maziyarpanahi Apr 13, 2023
4dd6a63
Update unit test for BERT (temp)
maziyarpanahi Apr 13, 2023
fd4be2d
Update unit test for DeBERTa (temp)
maziyarpanahi Apr 13, 2023
2068899
Update onnxruntime and google cloud dependencies
maziyarpanahi Apr 24, 2023
8c745cc
Seems Apple Silicon and Aarch64 are supported in onnxruntime
maziyarpanahi Apr 24, 2023
ef3233b
Cleaning up
maziyarpanahi Apr 24, 2023
81dc5d4
Remove bad merge
maziyarpanahi Apr 24, 2023
2de44f9
Update BERT unit test
maziyarpanahi Apr 24, 2023
70763b7
Merge branch 'master' into feature/onnx-runtime
maziyarpanahi Apr 26, 2023
d507528
Add fix me to the try
maziyarpanahi Apr 26, 2023
f71e645
Making withSafeOnnxModelLoader thread safe
maziyarpanahi May 1, 2023
bafb5df
Merge branch 'master' into feature/onnx-runtime
maziyarpanahi May 18, 2023
9afd6ab
Merge branch 'master' into feature/onnx-runtime
maziyarpanahi May 30, 2023
2578b4b
update onnxruntime
maziyarpanahi Jun 8, 2023
3415fc1
Merge branch 'master' into feature/onnx-runtime
maziyarpanahi Jun 12, 2023
103818b
Revert back to normal unit tests for now [skip test]
maziyarpanahi Jun 15, 2023
56867fd
Merge branch 'release/500-release-candidate' into feature/onnx-runtime
maziyarpanahi Jun 19, 2023
ffdc375
Added ADT for ModelEngine (#13862)
wolliq Jun 23, 2023
a675e97
Optimize ONNX on CPU
maziyarpanahi Jun 25, 2023
e67d79d
refactor
maziyarpanahi Jun 25, 2023
a2dc8c6
Add ONNX support to DistilBERT
maziyarpanahi Jun 25, 2023
2fe844f
Add support for ONNX in RoBERTa
maziyarpanahi Jun 25, 2023
a904234
Fix the bad serialization on write
maziyarpanahi Jun 25, 2023
0521edd
Fix using the wrong object
maziyarpanahi Jun 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ val tensorflowDependencies: Seq[sbt.ModuleID] =
else
Seq(tensorflowCPU)

// Resolve the ONNX Runtime coordinate for this build flavor.
// The plain "onnxruntime" CPU artifact also bundles native libs for Apple
// Silicon and Linux aarch64 (see onnxruntime 1.15 release notes), so only
// the GPU flag selects a different artifact — the previous
// is_silicon / is_aarch64 branches all resolved to the same Seq(onnxCPU).
val onnxDependencies: Seq[sbt.ModuleID] =
  if (is_gpu.equals("true"))
    Seq(onnxGPU)
  else
    Seq(onnxCPU)

lazy val mavenProps = settingKey[Unit]("workaround for Maven properties")

lazy val root = (project in file("."))
Expand All @@ -175,6 +185,7 @@ lazy val root = (project in file("."))
testDependencies ++
utilDependencies ++
tensorflowDependencies ++
onnxDependencies ++
typedDependencyParserDependencies,
// TODO potentially improve this?
mavenProps := {
Expand Down
5 changes: 4 additions & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ object Dependencies {
val tensorflowM1 = "com.johnsnowlabs.nlp" %% "tensorflow-m1" % tensorflowVersion
val tensorflowLinuxAarch64 = "com.johnsnowlabs.nlp" %% "tensorflow-aarch64" % tensorflowVersion

// ONNX Runtime: "onnxruntime" is the CPU build (also covers Apple Silicon
// and Linux aarch64); "onnxruntime_gpu" is the CUDA-enabled build. Both are
// plain Java artifacts, hence the single-% (non-cross-versioned) coordinate.
val onnxRuntimeVersion = "1.15.0"
val onnxCPU = "com.microsoft.onnxruntime" % "onnxruntime" % onnxRuntimeVersion
val onnxGPU = "com.microsoft.onnxruntime" % "onnxruntime_gpu" % onnxRuntimeVersion

// Google Cloud Storage client, used for cloud model download/upload.
// NOTE(review): the diff residue carried both the old "2.16.0" and the new
// "2.20.1" definition of gcpStorageVersion; only the updated value is kept —
// a duplicate val of the same name would not compile.
val gcpStorageVersion = "2.20.1"
val gcpStorage = "com.google.cloud" % "google-cloud-storage" % gcpStorageVersion

/** ------- Dependencies end ------- */
Expand Down
280 changes: 189 additions & 91 deletions src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.ModelArch
import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

Expand All @@ -35,6 +37,8 @@ import scala.collection.JavaConverters._
*
* @param tensorflowWrapper
* Bert Model wrapper with TensorFlow Wrapper
* @param onnxWrapper
* Bert Model wrapper with ONNX Wrapper
* @param sentenceStartTokenId
* Id of sentence start Token
* @param sentenceEndTokenId
Expand All @@ -47,7 +51,8 @@ import scala.collection.JavaConverters._
* Source: [[https://github.com/google-research/bert]]
*/
private[johnsnowlabs] class Bert(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
sentenceStartTokenId: Int,
sentenceEndTokenId: Int,
configProtoBytes: Option[Array[Byte]] = None,
Expand All @@ -57,6 +62,10 @@ private[johnsnowlabs] class Bert(
extends Serializable {

// TF signature map for feed/fetch names; falls back to the library defaults
// when the saved model carried no signatures.
val _tfBertSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply())
// Select the inference backend from whichever wrapper the caller supplied.
// NOTE(review): when NEITHER wrapper is defined this silently reports
// TensorFlow, and a later tensorflowWrapper.get would throw on an empty
// Option — consider failing fast here instead. TODO confirm an instance is
// ever constructed with both wrappers absent.
val detectedEngine: String =
if (tensorflowWrapper.isDefined) ModelEngine.tensorflow
else if (onnxWrapper.isDefined) ModelEngine.onnx
else ModelEngine.tensorflow
maziyarpanahi marked this conversation as resolved.
Show resolved Hide resolved

private def sessionWarmup(): Unit = {
val dummyInput =
Expand All @@ -74,51 +83,99 @@ private[johnsnowlabs] class Bert(
sessionWarmup()

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {

val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val batchLength = batch.length

val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors = tensors,
batch = batch,
maxSentenceLength = maxSentenceLength,
batchLength = batchLength)

val runner = tensorflowWrapper
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfBertSignatures
.getOrElse(ModelSignatureConstants.LastHiddenStateV1.key, "missing_sequence_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)
val embeddings = detectedEngine match {

case ModelEngine.onnx =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

// TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled.
try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("last_hidden_state")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()
// runner.close()
// env.close()
//
embeddings
} finally if (results != null) results.close()
}
case _ =>
val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors,
batch,
maxSentenceLength,
batchLength)

val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(
_tfBertSignatures
.getOrElse(
ModelSignatureConstants.LastHiddenStateV1.key,
"missing_sequence_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()

embeddings

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()
}

PrepareEmbeddings.prepareBatchWordEmbeddings(
batch,
Expand All @@ -133,48 +190,91 @@ private[johnsnowlabs] class Bert(
val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max
val batchLength = batch.length

val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors = tensors,
batch = batch,
maxSentenceLength = maxSentenceLength,
batchLength = batchLength)

val runner = tensorflowWrapper
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfBertSignatures
.getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()
val embeddings = detectedEngine match {
case ModelEngine.onnx =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("pooler_output")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()
// runner.close()
// env.close()
//
embeddings
} finally if (results != null) results.close()
}
case _ =>
val tensors = new TensorResources()

val (tokenTensors, maskTensors, segmentTensors) =
PrepareEmbeddings.prepareBatchTensorsWithSegment(
tensors,
batch,
maxSentenceLength,
batchLength)

val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

runner
.feed(
_tfBertSignatures.getOrElse(
ModelSignatureConstants.InputIdsV1.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMaskV1.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfBertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIdsV1.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfBertSignatures
.getOrElse(ModelSignatureConstants.PoolerOutput.key, "missing_pooled_output_key"))

val outs = runner.run().asScala
val embeddings = TensorResources.extractFloats(outs.head)

tokenTensors.close()
maskTensors.close()
segmentTensors.close()
tensors.clearSession(outs)
tensors.clearTensors()

embeddings

}
val dim = embeddings.length / batchLength
embeddings.grouped(dim).toArray

Expand All @@ -200,17 +300,17 @@ private[johnsnowlabs] class Bert(
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0L))
}

val runner = tensorflowWrapper
val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createLongBufferTensor(shape, segmentBuffers)

val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
.runner

val tokenTensors = tensors.createLongBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createLongBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createLongBufferTensor(shape, segmentBuffers)

runner
.feed(
_tfBertSignatures.getOrElse(
Expand Down Expand Up @@ -257,7 +357,6 @@ private[johnsnowlabs] class Bert(
maxSentenceLength,
sentenceStartTokenId,
sentenceEndTokenId)

val vectors = tag(encoded)

/*Combine tokens and calculated embeddings*/
Expand Down Expand Up @@ -324,7 +423,6 @@ private[johnsnowlabs] class Bert(
maxSentenceLength,
sentenceStartTokenId,
sentenceEndTokenId)

val embeddings = if (isLong) {
tagSequenceSBert(encoded)
} else {
Expand Down
Loading