From ecd3021a1ea759a1a611ac2b345c2f9b4b4e1edb Mon Sep 17 00:00:00 2001 From: ahmedlone127 Date: Fri, 8 Nov 2024 03:46:43 +0500 Subject: [PATCH] Update Bart.scala --- .../scala/com/johnsnowlabs/ml/ai/Bart.scala | 308 +++++++++--------- 1 file changed, 153 insertions(+), 155 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala index 1efaddb23b5313..fe79d4cd3d123b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala @@ -34,24 +34,24 @@ import org.tensorflow.{Session, Tensor} import scala.collection.JavaConverters._ /** This class is used to run Bart model for For Sequence Batches of WordpieceTokenizedSentence. - * Input for this model must be tokenized with a SentencePieceModel, - * - * @param tensorflow - * BART Model wrapper with TensorFlowWrapper - * @param configProtoBytes - * Configuration for TensorFlow session - */ + * Input for this model must be tokenized with a SentencePieceModel, + * + * @param tensorflow + * BART Model wrapper with TensorFlowWrapper + * @param configProtoBytes + * Configuration for TensorFlow session + */ private[johnsnowlabs] class Bart( - val tensorflowWrapper: Option[TensorflowWrapper], - val onnxWrapper: Option[EncoderDecoderWithoutPastWrappers], - val openvinoWrapper: Option[OpenvinoEncoderDecoderWithoutPastWrappers], - configProtoBytes: Option[Array[Byte]] = None, - signatures: Option[Map[String, String]] = None, - merges: Map[(String, String), Int], - vocabulary: Map[String, Int], - useCache: Boolean = false) - extends Serializable + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[EncoderDecoderWithoutPastWrappers], + val openvinoWrapper: Option[OpenvinoEncoderDecoderWithoutPastWrappers], + configProtoBytes: Option[Array[Byte]] = None, + signatures: Option[Map[String, String]] = None, + merges: Map[(String, String), Int], + vocabulary: Map[String, Int], + useCache: Boolean = false) + extends Serializable with Generate { val bpeTokenizer: BartTokenizer = BpeTokenizer @@ -75,6 +75,7 @@ private[johnsnowlabs] class Bart( else if (openvinoWrapper.isDefined) Openvino.name else TensorFlow.name + private object OnnxSignatures { val encoderInputIDs: String = "input_ids" val encoderAttentionMask: String = "attention_mask" @@ -103,51 +104,51 @@ private[johnsnowlabs] class Bart( /** @param sentences - * Sequence of WordpieceTokenizedSentence - * @param batchSize - * Batch size - * @param minOutputLength - * Minimum length of output - * @param maxOutputLength - * Maximum length of output - * @param doSample - * Whether to sample or not - * @param temperature - * Temperature for sampling - * @param topK - * Top K for sampling - * @param topP - * Top P for sampling - * @param repetitionPenalty - * Repetition penalty for sampling - * @param noRepeatNgramSize - * No repeat ngram size for sampling - * @param task - * Task - * @param randomSeed - * Random seed - * @param ignoreTokenIds - * Ignore token ids - * @param beamSize - * Beam size - * @return - */ + * Sequence of WordpieceTokenizedSentence + * @param batchSize + * Batch size + * @param minOutputLength + * Minimum length of output + * @param maxOutputLength + * Maximum length of output + * @param doSample + * Whether to sample or not + * @param temperature + * Temperature for sampling + * @param topK + * Top K for sampling + * @param topP + * Top P for sampling + * @param repetitionPenalty + * Repetition penalty for sampling + * @param noRepeatNgramSize + * No repeat ngram size for sampling + * @param task + * Task + * @param randomSeed + * Random seed + * @param ignoreTokenIds + * Ignore token ids + * @param beamSize + * Beam size + * @return + */ def predict( - sentences: Seq[Annotation], - batchSize: Int, - minOutputLength: Int, - maxOutputLength: Int, - doSample: Boolean, - temperature: Double, - topK: Int, - topP: Double, - repetitionPenalty: Double, - noRepeatNgramSize: Int, - task: String, - randomSeed: Option[Long] = None, - ignoreTokenIds: Array[Int] = Array(), - beamSize: Int, - maxInputLength: Int): Seq[Annotation] = { + sentences: Seq[Annotation], + batchSize: Int, + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + task: String, + randomSeed: Option[Long] = None, + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Seq[Annotation] = { val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => val batchSP = encode(batch, task) @@ -189,46 +190,46 @@ private[johnsnowlabs] class Bart( } /** @param batch - * Sequence of WordpieceTokenizedSentence - * @param minOutputLength - * Minimum length of output - * @param maxOutputLength - * Maximum length of output - * @param doSample - * Whether to sample or not - * @param temperature - * Temperature for sampling - * @param topK - * Top K for sampling - * @param topP - * Top P for sampling - * @param repetitionPenalty - * Repetition penalty for sampling - * @param noRepeatNgramSize - * No repeat ngram size for sampling - * @param randomSeed - * Random seed - * @param ignoreTokenIds - * Ignore token ids - * @param beamSize - * Beam size - * @return - * Sequence of WordpieceTokenizedSentence - */ + * Sequence of WordpieceTokenizedSentence + * @param minOutputLength + * Minimum length of output + * @param maxOutputLength + * Maximum length of output + * @param doSample + * Whether to sample or not + * @param temperature + * Temperature for sampling + * @param topK + * Top K for sampling + * @param topP + * Top P for sampling + * @param repetitionPenalty + * Repetition penalty for sampling + * @param noRepeatNgramSize + * No repeat ngram size for sampling + * @param randomSeed + * Random seed + * @param ignoreTokenIds + * Ignore token ids + * @param beamSize + * Beam size + * @return + * Sequence of WordpieceTokenizedSentence + */ def tag( - batch: Seq[Array[Int]], - minOutputLength: Int, - maxOutputLength: Int, - doSample: Boolean, - temperature: Double, - topK: Int, - topP: Double, - repetitionPenalty: Double, - noRepeatNgramSize: Int, - randomSeed: Option[Long], - ignoreTokenIds: Array[Int] = Array(), - beamSize: Int, - maxInputLength: Int): Array[Array[Int]] = { + batch: Seq[Array[Int]], + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long], + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Array[Array[Int]] = { val ignoreTokenIdsInt = ignoreTokenIds val expandedEncoderInputIdsVals = @@ -236,6 +237,7 @@ private[johnsnowlabs] class Bart( val sequencesLength = expandedEncoderInputIdsVals.map(x => x.length).toArray val maxSentenceLength = sequencesLength.max // - curLen + val numReturn_sequences = 1 // from config @@ -296,11 +298,8 @@ private[johnsnowlabs] class Bart( ModelSignatureConstants.EncoderAttentionMask.key, "missing_encoder_attention_mask"), encoderAttentionMaskTensors) - .fetch( - _tfBartSignatures - .getOrElse( - ModelSignatureConstants.CachedEncoderOutput.key, - "missing_last_hidden_state")) + .fetch(_tfBartSignatures + .getOrElse(ModelSignatureConstants.CachedEncoderOutput.key, "missing_last_hidden_state")) val encoderOuts = runner.run().asScala val encoderOutsFloats = TensorResources.extractFloats(encoderOuts.head) @@ -361,11 +360,9 @@ private[johnsnowlabs] class Bart( nextStateTensor2 = None } modelOutputs - } - else if (detectedEngine == ONNX.name) { + else if (detectedEngine == ONNX.name) { { - var (encoderSession, encoderEnv): (OrtSession, OrtEnvironment) = (null, null) var (decoderSession, decoderEnv): (OrtSession, OrtEnvironment) = (null, null) @@ -378,14 +375,10 @@ private[johnsnowlabs] class Bart( decoderEnv = _decoderEnv val encoderAttentionMask: OnnxTensor = - OnnxTensor.createTensor( - encoderEnv, - expandedEncoderInputIdsVals.toArray.map(_.map(_ => 1L))) + OnnxTensor.createTensor(encoderEnv, expandedEncoderInputIdsVals.toArray.map(_.map(_ => 1L))) val encoderInputTensors: OnnxTensor = - OnnxTensor.createTensor( - encoderEnv, - expandedEncoderInputIdsVals.toArray.map(_.map(_.toLong))) + OnnxTensor.createTensor(encoderEnv, expandedEncoderInputIdsVals.toArray.map(_.map(_.toLong))) val encoderInputs: java.util.Map[String, OnnxTensor] = Map( OnnxSignatures.encoderInputIDs -> encoderInputTensors, @@ -411,6 +404,8 @@ private[johnsnowlabs] class Bart( if (encoderResults != null) encoderResults.close() } + + val decoderEncoderStateTensors = OnnxTensor.createTensor(encoderEnv, encoderStateBuffer) val modelOutputs = generate( batch, @@ -432,7 +427,7 @@ private[johnsnowlabs] class Bart( this.paddingTokenId, randomSeed, ignoreTokenIdsInt, - Right((decoderEnv, decoderSession))) + Right((decoderEnv,decoderSession))) encoderInputTensors.close() encoderAttentionMask.close() @@ -440,7 +435,7 @@ private[johnsnowlabs] class Bart( modelOutputs } - } + } else { val encoderInferRequest = @@ -522,23 +517,23 @@ private[johnsnowlabs] class Bart( } /** Decode a sequence of sentences - * @param sentences - * Sequence of sentences - * @return - * Sequence of decoded sentences - */ + * @param sentences + * Sequence of sentences + * @return + * Sequence of decoded sentences + */ def decode(sentences: Array[Array[Int]]): Seq[String] = { sentences.map(s => bpeTokenizer.decodeTokens(s.map(_.toInt))) } /** Encode a sequence of sentences - * @param sentences - * Sequence of sentences - * @param task - * Task - * @return - * Sequence of encoded sentences - */ + * @param sentences + * Sequence of sentences + * @param task + * Task + * @return + * Sequence of encoded sentences + */ def encode(sentences: Seq[Annotation], task: String): Seq[Array[Int]] = { SentenceSplit .unpack(sentences) @@ -554,29 +549,29 @@ private[johnsnowlabs] class Bart( } /** Get model output for a batch of input sequences - * @param encoderInputIds - * input ids - * @param decoderInputIds - * decoder input ids - * @param decoderEncoderStateTensors - * encoder state - * @param encoderAttentionMaskTensors - * attention mask - * @param maxLength - * max length - * @param session - * tensorflow session - * @return - * model output - */ + * @param encoderInputIds + * input ids + * @param decoderInputIds + * decoder input ids + * @param decoderEncoderStateTensors + * encoder state + * @param encoderAttentionMaskTensors + * attention mask + * @param maxLength + * max length + * @param session + * tensorflow session + * @return + * model output + */ override def getModelOutput( - encoderInputIds: Seq[Array[Int]], - decoderInputIds: Seq[Array[Int]], - decoderEncoderStateTensors: Either[Tensor, OnnxTensor], - encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], - maxLength: Int, - session: Either[Session, (OrtEnvironment, OrtSession)], - ovInferRequest: Option[InferRequest]): Array[Array[Float]] = { + encoderInputIds: Seq[Array[Int]], + decoderInputIds: Seq[Array[Int]], + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], + maxLength: Int, + session: Either[Session, (OrtEnvironment, OrtSession)], + ovInferRequest: Option[InferRequest]): Array[Array[Float]] = { if (detectedEngine == TensorFlow.name) { // extract decoderEncoderStateTensors, encoderAttentionMaskTensors and Session from LEFT @@ -711,18 +706,18 @@ private[johnsnowlabs] class Bart( } decoderInputTensors.close() nextTokenLogits - } else { - + } + else if (detectedEngine == ONNX.name) { val (env, decoderSession) = session.right.get val decoderInputLength = decoderInputIds.head.length - val sequenceLength = decoderInputLength + val sequenceLength =decoderInputLength val batchSize = encoderInputIds.length val decoderInputIdsLong: Array[Array[Long]] = - decoderInputIds.map { tokenIds => tokenIds.map(_.toLong) }.toArray.map { tokenIds => - tokenIds - } + decoderInputIds.map { tokenIds => tokenIds.map(_.toLong) }. + toArray.map { tokenIds =>tokenIds} + val decoderInputIdsLongTensor: OnnxTensor = OnnxTensor.createTensor(env, decoderInputIdsLong) @@ -747,6 +742,7 @@ private[johnsnowlabs] class Bart( OnnxSignatures.decoderEncoderState -> decoderEncoderStateTensor).asJava val sessionOutput = decoderSession.run(decoderInputs) + val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) val decoderOutputs = (0 until batchSize).map(i => { logitsRaw @@ -789,6 +785,7 @@ private[johnsnowlabs] class Bart( } } + private def sessionWarmup(): Unit = { val dummyInput = Array.fill(1)(0) ++ Array(eosTokenId) tag( @@ -806,5 +803,6 @@ private[johnsnowlabs] class Bart( beamSize = 1, maxInputLength = 512) + } }