From 8045e3f9fa095e8eb0ec4e9fc4e2ab398ba51a2d Mon Sep 17 00:00:00 2001
From: Prabod Rathnayaka
Date: Wed, 17 Jan 2024 10:30:57 +0000
Subject: [PATCH 01/11] introducing LLAMA2

---
 .../scala/com/johnsnowlabs/ml/ai/LLAMA2.scala | 276 ++++++++++++
 .../ml/util/LoadExternalModel.scala           |  18 +-
 .../seq2seq/LLAMA2Transformer.scala           | 425 ++++++++++++++++++
 .../annotators/seq2seq/LLAMA2TestSpec.scala   |  55 +++
 4 files changed, 769 insertions(+), 5 deletions(-)
 create mode 100644 src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
 create mode 100644 src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala
 create mode 100644 src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala

diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
new file mode 100644
index 00000000000000..fdcf7046147665
--- /dev/null
+++ b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
@@ -0,0 +1,276 @@
+/*
+ * Copyright 2017 - 2023  John Snow Labs
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.johnsnowlabs.ml.ai
+
+import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
+import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig
+import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers
+import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
+import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
+import com.johnsnowlabs.nlp.Annotation
+import scala.collection.JavaConverters._
+import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
+
+private[johnsnowlabs] class LLAMA2(
+    val onnxWrappers: DecoderWrappers,
+    val spp: SentencePieceWrapper,
+    generationConfig: GenerationConfig)
+    extends Serializable {
+
+  private val GenerationConfig(
+    bosTokenId: Int,
+    _,
+    eosTokenId: Int,
+    vocabSize: Int,
+    beginSuppressTokens,
+    suppressTokenIds,
+    forcedDecoderIds) =
+    generationConfig
+  private val pieceSize = spp.getSppModel.getPieceSize
+
+  /** Decode a sequence of sentences
+    * @param sentences
+    *   Sequence of sentences as SentencePiece token ids
+    * @return
+    *   Sequence of decoded sentences
+    */
+  def decode(sentences: Array[Array[Int]]): Seq[String] = {
+    sentences.map { s =>
+      val filteredPieceIds = s.filter(x => x <= pieceSize)
+      spp.getSppModel.decodeIds(filteredPieceIds.map(_.toInt): _*)
+    }
+  }
+
+  /** Encode a sequence of sentences into SentencePiece token ids
+    * @param sentences
+    *   Sequence of sentence annotations
+    * @return
+    *   Sequence of encoded sentences
+    */
+  def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = {
+    sentences.map(s => {
+      val sentWithTask = s.result
+      spp.getSppModel.encodeAsIds(sentWithTask)
+    })
+  }
+
+  def tag(
+      batch: Seq[Array[Int]],
+      minOutputLength: Int,
+      maxOutputLength: Int,
+      doSample: Boolean,
+      temperature: Double,
+      topK: Int,
+      topP: Double,
+      repetitionPenalty: Double,
+      noRepeatNgramSize: Int,
+      randomSeed: Option[Long],
+      ignoreTokenIds: Array[Int] = Array(),
+      beamSize: Int,
+      maxInputLength: Int): Array[Array[Int]] = {
+    val (encoderSession, env) = 
onnxWrappers.decoder.getSession() + val ignoreTokenIdsInt = ignoreTokenIds + val expandedEncoderInputIdsVals = + batch.flatMap(x => List.fill(beamSize)(x.take(maxInputLength))) + val sequencesLength = expandedEncoderInputIdsVals.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + + val numReturn_sequences = 1 + // from config + + var effectiveBatch_size = 1 + var effectiveBatch_mult = 1 + if (doSample) { + effectiveBatch_size = expandedEncoderInputIdsVals.length * numReturn_sequences + effectiveBatch_mult = numReturn_sequences + } else { + effectiveBatch_size = expandedEncoderInputIdsVals.length + effectiveBatch_mult = 1 + } + + // Run the prompt through the decoder and get the past + val decoderOutputs = + generateGreedyOnnx( + expandedEncoderInputIdsVals.toArray, + (encoderSession, env), + maxOutputLength) + decoderOutputs + } + + def predict( + sentences: Seq[Annotation], + batchSize: Int, + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long] = None, + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Seq[Annotation] = { + + val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => + val batchSP = encode(batch) + val spIds = tag( + batchSP, + minOutputLength, + maxOutputLength, + doSample, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + randomSeed, + ignoreTokenIds, + beamSize, + maxInputLength) + + decode(spIds) + + } + + var sentBegin, nextSentEnd = 0 + val annotations = batchDecoder.zip(sentences).map { case (content, sent) => + nextSentEnd += content.length - 1 + val annots = new Annotation( + annotatorType = DOCUMENT, + begin = sentBegin, + end = nextSentEnd, + result = content, + metadata = sent.metadata) + sentBegin += nextSentEnd + 1 + annots + } + annotations + } + + private def getDecoderOutputsWithPast( + inputIds: Array[Array[Int]], + decoderPast: Map[String, OnnxTensor], + onnxSession: (OrtSession, OrtEnvironment)) + : (Array[Array[Float]], Map[String, OnnxTensor]) = { + val (session, env) = onnxSession + + val lastTokens: Array[Array[Long]] = + inputIds.map { tokenIds => + Array(tokenIds.last.toLong) + } + + val lastTokensTensor: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens.map(_.map(_ => 1L))) + val decoderWithPastInputs: java.util.Map[String, OnnxTensor] = (Map( + OnnxSignatures.decoderInputIDs -> lastTokensTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask) ++ decoderPast).asJava + val sessionOutput = session.run(decoderWithPastInputs) + val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderPresent = sessionOutput.getOnnxTensors(OnnxSignatures.decoderPresent) + lastTokensTensor.close() + val batchLogits = logits.grouped(vocabSize).toArray + (batchLogits, decoderPresent) + + } + + private def getDecoderOutputs( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment)): (Array[Array[Float]]) = { + val (session, env) = onnxSession + + val inputIdsLong: Array[Array[Long]] = + inputIds.map { tokenIds => tokenIds.map(_.toLong) } + + val inputIdsLongTensor: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong.map(_.map(_ => 1L))) + val decoderInputs: java.util.Map[String, 
OnnxTensor] = Map( + OnnxSignatures.decoderInputIDs -> inputIdsLongTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask).asJava + val sessionOutput = session.run(decoderInputs) + val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + inputIdsLongTensor.close() + val batchLogits = logits.grouped(vocabSize).toArray + batchLogits + + } + + /** Gets the index with the highest score + * + * @param scores + * Array of Scores to max + * @return + * Index of the highest score + */ + private def argmax(scores: Array[Float]): Int = + scores.zipWithIndex.maxBy { case (score, _) => + score + }._2 + private def greedyGenerationFinished( + decoderIds: Seq[Array[Int]], + eosTokenId: Int, + maxOutputLength: Int): Boolean = + decoderIds.map(_.last).forall(_ == eosTokenId) || decoderIds.head.length == maxOutputLength + + private def generateGreedyOnnx( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment), + maxOutputLength: Int): (Array[Array[Int]]) = { + + val sequencesLength = inputIds.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + var generatedIds: Array[Array[Int]] = inputIds + while (!greedyGenerationFinished( + generatedIds, + eosTokenId, + maxOutputLength + maxSentenceLength)) { + + val (batchLogits: Array[Array[Float]]) = + Array(getDecoderOutputs(generatedIds, onnxSession).last) + + val nextTokenIds: Array[Int] = batchLogits.map(argmax) + nextTokenIds.foreach(x => println(s"new ids:$x")) + generatedIds = + generatedIds.zip(nextTokenIds).map { case (currentIds: Array[Int], nextId: Int) => + currentIds ++ Array(nextId) + } + // print lens of generatedIds + generatedIds.foreach(x => println(x.length)) + } + generatedIds + } + + private object OnnxSignatures { + val decoderInputIDs: String = "input_ids" + val decoderAttentionMask: String = "attention_mask" + // create decoder past for 32 layers of key and value eg. 
past_key_values.0.key and past_key_values.0.value + val decoderPast: Array[String] = (0 until 32) + .flatMap(i => Seq(s"past_key_values.$i.key", s"past_key_values.$i.value")) + .toArray + val decoderOutput: String = "logits" + val decoderPresent: Array[String] = + (0 until 32).flatMap(i => Seq(s"present.$i.key", s"present.$i.value")).toArray + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala index 9848d6ae142509..827e9e7b5b2be8 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala @@ -64,13 +64,19 @@ object LoadExternalModel { def isOnnxModel( modelPath: String, isEncoderDecoder: Boolean = false, - withPast: Boolean = false): Boolean = { + withPast: Boolean = false, + isDecoder: Boolean = false): Boolean = { if (isEncoderDecoder) { val onnxEncoderModel = new File(modelPath, ONNX.encoderModel) val onnxDecoderModel = if (withPast) new File(modelPath, ONNX.decoderWithPastModel) else new File(modelPath, ONNX.decoderModel) onnxEncoderModel.exists() && onnxDecoderModel.exists() + } else if (isDecoder) { + val onnxDecoderModel = + if (withPast) new File(modelPath, ONNX.decoderWithPastModel) + else new File(modelPath, ONNX.decoderModel) + onnxDecoderModel.exists() } else { val onnxModel = new File(modelPath, ONNX.modelName) onnxModel.exists() @@ -81,7 +87,8 @@ object LoadExternalModel { def detectEngine( modelPath: String, isEncoderDecoder: Boolean = false, - withPast: Boolean = false): String = { + withPast: Boolean = false, + isDecoder: Boolean = false): String = { /** Check if the path is correct */ val f = new File(modelPath) @@ -98,7 +105,7 @@ object LoadExternalModel { val tfSavedModelExist = isTensorFlowModel(modelPath) /*ONNX required model's name*/ - val onnxModelExist = isOnnxModel(modelPath, isEncoderDecoder, withPast) + val onnxModelExist = isOnnxModel(modelPath, isEncoderDecoder, withPast, isDecoder) if (tfSavedModelExist) { TensorFlow.name @@ -125,10 +132,11 @@ object LoadExternalModel { def modelSanityCheck( path: String, isEncoderDecoder: Boolean = false, - withPast: Boolean = false): (String, String) = { + withPast: Boolean = false, + isDecoder: Boolean = false): (String, String) = { val localPath: String = ResourceHelper.copyToLocal(path) - (localPath, detectEngine(localPath, isEncoderDecoder, withPast)) + (localPath, detectEngine(localPath, isEncoderDecoder, withPast, isDecoder)) } def loadTextAsset(assetPath: String, assetName: String): Array[String] = { diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala new file mode 100644 index 00000000000000..32802ea9979579 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -0,0 +1,425 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.seq2seq +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.LLAMA2 +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadJsonStringAsset, + loadSentencePieceAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.ONNX +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.nlp.serialization.MapFeature +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession +import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** GPT-2: the OpenAI Text-To-Text Transformer + * + * GPT-2 is a large transformer-based language model with 1.5 billion parameters, trained on a + * dataset of 8 million web pages. GPT-2 is trained with a simple objective: predict the next + * word, given all of the previous words within some text. The diversity of the dataset causes + * this simple goal to contain naturally occurring demonstrations of many tasks across diverse + * domains. GPT-2 is a direct scale-up of GPT, with more than 10X the parameters and trained on + * more than 10X the amount of data. + * + * GPT-2 displays a broad set of capabilities, including the ability to generate conditional + * synthetic text samples of unprecedented quality, where we prime the model with an input and + * have it generate a lengthy continuation. In addition, GPT-2 outperforms other language models + * trained on specific domains (like Wikipedia, news, or books) without needing to use these + * domain-specific training datasets. On language tasks like question answering, reading + * comprehension, summarization, and translation, GPT-2 begins to learn these tasks from the raw + * text, using no task-specific training data. While scores on these downstream tasks are far + * from state-of-the-art, they suggest that the tasks can benefit from unsupervised techniques, + * given sufficient (unlabeled) data and compute. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val gpt2 = LLAMA2Transformer.pretrained() + * .setInputCols("document") + * .setOutputCol("generation") + * }}} + * The default model is `"gpt2"`, if no name is provided. For available pretrained models please + * see the [[https://sparknlp.org/models?q=gpt2 Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2TestSpec.scala GPT2TestSpec]]. 
+ *
+ * '''References:'''
+ *   - [[https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf Language Models are Unsupervised Multitask Learners]]
+ *   - [[https://github.com/openai/gpt-2]]
+ *
+ * '''Paper Abstract:'''
+ *
+ * ''Natural language processing tasks, such as question answering, machine translation, reading
+ * comprehension, and summarization, are typically approached with supervised learning on
+ * taskspecific datasets. We demonstrate that language models begin to learn these tasks without
+ * any explicit supervision when trained on a new dataset of millions of webpages called WebText.
+ * When conditioned on a document plus questions, the answers generated by the language model
+ * reach F1 on the CoQA dataset - matching or exceeding the performance of 3 out of 4 baseline
+ * systems without using the 127,000+ training examples. The capacity of the language model is
+ * essential to the success of zero-shot task transfer and increasing it improves performance in
+ * a log-linear fashion across tasks. Our largest model, GPT-2, is a 1.5B parameter Transformer
+ * that achieves state of the art results on 7 out of 8 tested language modeling datasets in a
+ * zero-shot setting but still underfits WebText. Samples from the model reflect these
+ * improvements and contain coherent paragraphs of text. These findings suggest a promising path
+ * towards building language processing systems which learn to perform tasks from their naturally
+ * occurring demonstrations.''
+ *
+ * '''Note:'''
+ *
+ * This is a very computationally expensive module especially on larger sequence. The use of an
+ * accelerator such as GPU is recommended.
+ *
+ * ==Example==
+ * {{{
+ * import spark.implicits._
+ * import com.johnsnowlabs.nlp.base.DocumentAssembler
+ * import com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Transformer
+ * import org.apache.spark.ml.Pipeline
+ *
+ * val documentAssembler = new DocumentAssembler()
+ *   .setInputCol("text")
+ *   .setOutputCol("documents")
+ *
+ * val gpt2 = LLAMA2Transformer.pretrained("gpt2")
+ *   .setInputCols(Array("documents"))
+ *   .setMinOutputLength(10)
+ *   .setMaxOutputLength(50)
+ *   .setDoSample(false)
+ *   .setTopK(50)
+ *   .setNoRepeatNgramSize(3)
+ *   .setOutputCol("generation")
+ *
+ * val pipeline = new Pipeline().setStages(Array(documentAssembler, gpt2))
+ *
+ * val data = Seq(
+ *   "My name is Leonardo."
+ * ).toDF("text")
+ * val result = pipeline.fit(data).transform(data)
+ *
+ * result.select("generation.result").show(truncate = false)
+ * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+ * |result                                                                                                                                                                                              |
+ * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+ * |[ My name is Leonardo. I am a man of letters. I have been a man for many years. I was born in the year 1776. 
I came to the United States in 1776, and I have lived in the United Kingdom since 1776]| + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * }}} + * + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class LLAMA2Transformer(override val uid: String) + extends AnnotatorModel[LLAMA2Transformer] + with HasBatchedAnnotate[LLAMA2Transformer] + with ParamsAndFeaturesWritable + with WriteOnnxModel + with HasGeneratorProperties + with HasEngine { + + def this() = this(Identifiable.randomUID("LLAMA2TRANSFORMER")) + + /** Input annotator type : DOCUMENT + * + * @group param + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT) + + /** Output annotator type : DOCUMENT + * + * @group param + */ + override val outputAnnotatorType: String = DOCUMENT + + /** @group setParam */ + def setRandomSeed(value: Int): LLAMA2Transformer.this.type = { + if (randomSeed.isEmpty) { + this.randomSeed = Some(value) + } + this + } + + /** A list of token ids which are ignored in the decoder's output (Default: `Array()`) + * + * @group param + */ + var ignoreTokenIds = new IntArrayParam( + this, + "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output") + + /** @group setParam */ + def setIgnoreTokenIds(tokenIds: Array[Int]): LLAMA2Transformer.this.type = { + set(ignoreTokenIds, tokenIds) + } + + /** @group getParam */ + def getIgnoreTokenIds: Array[Int] = $(ignoreTokenIds) + + private var _model: Option[Broadcast[LLAMA2]] = None + + val generationConfig: StructFeature[GenerationConfig] = + new StructFeature(this, "generationConfig").setProtected() + + def setGenerationConfig(value: GenerationConfig): this.type = + set(generationConfig, value) + + def getGenerationConfig: GenerationConfig = $$(generationConfig) + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + onnxWrappers: DecoderWrappers, + spp: SentencePieceWrapper): this.type = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new LLAMA2(onnxWrappers, spp = spp, generationConfig = getGenerationConfig))) + } + this + } + + /** @group getParam */ + def getModelIfNotSet: LLAMA2 = _model.get.value + + setDefault( + minOutputLength -> 0, + maxOutputLength -> 20, + doSample -> false, + temperature -> 1.0, + topK -> 50, + topP -> 1.0, + repetitionPenalty -> 1.0, + noRepeatNgramSize -> 3, + ignoreTokenIds -> Array(), + batchSize -> 1, + beamSize -> 1, + maxInputLength -> 512) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. 
not necessarily a one-to-one
+    *   relationship
+    */
+  override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = {
+
+    val allAnnotations = batchedAnnotations
+      .filter(_.nonEmpty)
+      .zipWithIndex
+      .flatMap { case (annotations, i) =>
+        annotations.filter(_.result.nonEmpty).map(x => (x, i))
+      }
+    println(s"allAnnotations: ${allAnnotations.length}")
+    println(s"${allAnnotations.head._1}")
+    val processedAnnotations = if (allAnnotations.nonEmpty) {
+      this.getModelIfNotSet.predict(
+        sentences = allAnnotations.map(_._1),
+        batchSize = $(batchSize),
+        minOutputLength = $(minOutputLength),
+        maxOutputLength = $(maxOutputLength),
+        doSample = $(doSample),
+        temperature = $(temperature),
+        topK = $(topK),
+        topP = $(topP),
+        repetitionPenalty = $(repetitionPenalty),
+        noRepeatNgramSize = $(noRepeatNgramSize),
+        randomSeed = this.randomSeed,
+        ignoreTokenIds = $(ignoreTokenIds),
+        beamSize = $(beamSize),
+        maxInputLength = $(maxInputLength))
+    } else {
+      Seq()
+    }
+
+    // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence
+//    batchedAnnotations.indices.map(rowIndex => {
+//      val rowAnnotations = processedAnnotations
+//        // zip each annotation with its corresponding row index
+//        .zip(allAnnotations)
+//        // select the sentences belonging to the current row
+//        .filter(_._2._2 == rowIndex)
+//        // leave the annotation only
+//        .map(_._1)
+//
+//      if (rowAnnotations.nonEmpty)
+//        rowAnnotations
+//      else
+//        Seq.empty[Annotation]
+//    })
+//    Seq(Seq.empty[Annotation])
+    Seq(processedAnnotations)
+  }
+
+  override def onWrite(path: String, spark: SparkSession): Unit = {
+    super.onWrite(path, spark)
+    getEngine match {
+      case ONNX.name =>
+        val wrappers = getModelIfNotSet.onnxWrappers
+        writeOnnxModels(
+          path,
+          spark,
+          Seq((wrappers.decoder, "decoder_model.onnx")),
+          LLAMA2Transformer.suffix)
+    }
+  }
+}
+
+trait ReadablePretrainedLLAMA2TransformerModel
+    extends ParamsAndFeaturesReadable[LLAMA2Transformer]
+    with HasPretrained[LLAMA2Transformer] {
+  override val defaultModelName: Some[String] = Some("llama2")
+
+  /** Java compliant-overrides */
+  override def pretrained(): LLAMA2Transformer = super.pretrained()
+
+  override def pretrained(name: String): LLAMA2Transformer = super.pretrained(name)
+
+  override def pretrained(name: String, lang: String): LLAMA2Transformer =
+    super.pretrained(name, lang)
+
+  override def pretrained(name: String, lang: String, remoteLoc: String): LLAMA2Transformer =
+    super.pretrained(name, lang, remoteLoc)
+}
+
+trait ReadLLAMA2TransformerDLModel extends ReadOnnxModel with ReadSentencePieceModel {
+  this: ParamsAndFeaturesReadable[LLAMA2Transformer] =>
+
+  override val onnxFile: String = "llama2_onnx"
+  val suffix: String = "_llama2"
+  override val sppFile: String = "llama2_spp"
+
+  def readModel(instance: LLAMA2Transformer, path: String, spark: SparkSession): Unit = {
+    instance.getEngine match {
+      case ONNX.name =>
+        val wrappers = readOnnxModels(path, spark, Seq("decoder_model"), suffix)
+        val onnxWrappers =
+          DecoderWrappers(decoder = wrappers("decoder_model"))
+        val spp = readSentencePieceModel(path, spark, "_llama2_spp", sppFile)
+        instance.setModelIfNotSet(spark, onnxWrappers, spp)
+      case _ =>
+        throw new Exception(notSupportedEngineError)
+    }
+  }
+
+  addReader(readModel)
+
+  def loadSavedModel(modelPath: String, spark: SparkSession): LLAMA2Transformer = {
+    implicit val formats: DefaultFormats.type = DefaultFormats // for json4s
+    // print model path
+    println(s"$modelPath")
+    val 
(localModelPath, detectedEngine) = + modelSanityCheck(modelPath, isDecoder = true) + val modelConfig: JValue = + parse(loadJsonStringAsset(localModelPath, "config.json")) + + val beginSuppressTokens: Array[Int] = + (modelConfig \ "begin_suppress_tokens").extract[Array[Int]] + + val suppressTokenIds: Array[Int] = + (modelConfig \ "suppress_tokens").extract[Array[Int]] + + val forcedDecoderIds: Array[(Int, Int)] = + (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map { + case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 => + (idxWithTokenId(0), idxWithTokenId(1)) + case _ => + throw new Exception( + "Could not extract forced_decoder_ids. Should be a list of tuples with 2 entries.") + } + + def arrayOrNone[T](array: Array[T]): Option[Array[T]] = + if (array.nonEmpty) Some(array) else None + + val bosTokenId = (modelConfig \ "bos_token_id").extract[Int] + val eosTokenId = (modelConfig \ "eos_token_id").extract[Int] + val padTokenId = (modelConfig \ "eos_token_id").extract[Int] + val vocabSize = (modelConfig \ "vocab_size").extract[Int] + + val annotatorModel = new LLAMA2Transformer() + .setGenerationConfig( + GenerationConfig( + bosTokenId, + padTokenId, + eosTokenId, + vocabSize, + arrayOrNone(beginSuppressTokens), + arrayOrNone(suppressTokenIds), + arrayOrNone(forcedDecoderIds))) + val spModel = loadSentencePieceAsset(localModelPath, "tokenizer.model") + + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case ONNX.name => + val onnxWrapperDecoder = + OnnxWrapper.read( + modelPath, + zipped = false, + useBundle = true, + modelName = "decoder_model") + + val onnxWrappers = DecoderWrappers(onnxWrapperDecoder) + + annotatorModel + .setModelIfNotSet(spark, onnxWrappers, spModel) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } + +} + +object LLAMA2Transformer + extends ReadablePretrainedLLAMA2TransformerModel + with ReadLLAMA2TransformerDLModel diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala new file mode 100644 index 00000000000000..6ea7a3438b3430 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.johnsnowlabs.nlp.annotators.seq2seq
+
+import com.johnsnowlabs.nlp.base.DocumentAssembler
+import com.johnsnowlabs.nlp.util.io.ResourceHelper
+import com.johnsnowlabs.tags.{SlowTest, FastTest}
+import com.johnsnowlabs.util.Benchmark
+import org.apache.spark.ml.Pipeline
+import org.scalatest.flatspec.AnyFlatSpec
+
+class LLAMA2TestSpec extends AnyFlatSpec {
+
+  "bart-large-cnn" should "handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs FastTest in {
+    // Even though the paper states temperature in interval [0,1), using temperature=0 will result in a division by 0 error.
+    // Also doSample=True may result in infinities being generated and distFiltered.length==0, which results in an exception if we don't return 0 instead internally.
+    val testData = ResourceHelper.spark
+      .createDataFrame(Seq(
+        (1, "PG&E stated it scheduled the blackouts in response to forecasts for high winds ")))
+      .toDF("id", "text")
+      .repartition(1)
+    val documentAssembler = new DocumentAssembler()
+      .setInputCol("text")
+      .setOutputCol("documents")
+
+    val bart = LLAMA2Transformer
+      .loadSavedModel(
+        "/home/prabod/Projects/ModelZoo/BART/BART/custom_whisper_onnx/",
+        ResourceHelper.spark)
+      .setInputCols(Array("documents"))
+      .setDoSample(false)
+      .setMaxOutputLength(50)
+      .setOutputCol("generation")
+    new Pipeline()
+      .setStages(Array(documentAssembler, bart))
+      .fit(testData)
+      .transform(testData)
+      .show(truncate = false)
+
+  }
+}

From 492e6c39e7abedb403c017e06d430b8b037ebd82 Mon Sep 17 00:00:00 2001
From: Prabod Rathnayaka
Date: Wed, 17 Jan 2024 11:49:05 +0000
Subject: [PATCH 02/11] Added option to read model from model path to onnx wrapper

---
 .../scala/com/johnsnowlabs/ml/ai/LLAMA2.scala |  9 +++---
 .../johnsnowlabs/ml/onnx/OnnxWrapper.scala    | 32 ++++++++++++++++++-
 .../annotators/seq2seq/LLAMA2TestSpec.scala   |  4 +--
 3 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
index fdcf7046147665..1f22dc28c1b0f5 100644
--- a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
+++ b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
@@ -18,10 +18,12 @@ package com.johnsnowlabs.ml.ai
 
 import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}
 import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig
+import com.johnsnowlabs.ml.onnx.OnnxSession
 import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers
 import com.johnsnowlabs.ml.onnx.TensorResources.implicits._
 import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper
 import com.johnsnowlabs.nlp.Annotation
+
 import scala.collection.JavaConverters._
 import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
 
@@ -30,7 +32,7 @@ private[johnsnowlabs] class 
LLAMA2( Array(getDecoderOutputs(generatedIds, onnxSession).last) val nextTokenIds: Array[Int] = batchLogits.map(argmax) - nextTokenIds.foreach(x => println(s"new ids:$x")) generatedIds = generatedIds.zip(nextTokenIds).map { case (currentIds: Array[Int], nextId: Int) => currentIds ++ Array(nextId) } - // print lens of generatedIds - generatedIds.foreach(x => println(x.length)) } generatedIds } diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 7f4fb80fcff0e5..ff30ee3954400b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -94,6 +94,21 @@ object OnnxWrapper { (session, env) } + private def withSafeOnnxModelPathLoader( + onnxModelPath: String, + sessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) = + this.synchronized { + val env = OrtEnvironment.getEnvironment() + val sessionOptionsObject = if (sessionOptions.isEmpty) { + new SessionOptions() + } else { + mapToSessionOptionsObject(sessionOptions) + } + + val session = env.createSession(onnxModelPath, sessionOptionsObject) + (session, env) + } + def read( modelPath: String, zipped: Boolean = true, @@ -119,7 +134,20 @@ object OnnxWrapper { else Paths.get(folder, new File(folder).list().head).toString val modelFile = new File(onnxFile) val modelBytes = FileUtils.readFileToByteArray(modelFile) - val (session, env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + var session: OrtSession = null + var env: OrtEnvironment = null + try { + val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + session = _session + env = _env + } catch { + case e: Exception => { + println("Error loading model from file, trying to load from path") + val (_session, _env) = withSafeOnnxModelPathLoader(onnxFile, sessionOptions) + session = _session + env = _env + } + } // 4. 
Remove tmp folder FileHelper.delete(tmpFolder) @@ -209,4 +237,6 @@ object OnnxWrapper { encoder: OnnxWrapper, decoder: OnnxWrapper, decoderWithPast: OnnxWrapper) + + case class DecoderWrappers(decoder: OnnxWrapper) } diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala index 6ea7a3438b3430..c908fa54e909b2 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -38,9 +38,7 @@ class LLAMA2TestSpec extends AnyFlatSpec { .setOutputCol("documents") val bart = LLAMA2Transformer - .loadSavedModel( - "/home/prabod/Projects/ModelZoo/BART/BART/custom_whisper_onnx/", - ResourceHelper.spark) + .loadSavedModel("/home/prabod/Projects/ModelZoo/BART/BART/llama2_7b/", ResourceHelper.spark) .setInputCols(Array("documents")) .setDoSample(false) .setMaxOutputLength(50) From 48cb061ea7d5689fbb4dfa3d696ed3d31131e6df Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Wed, 17 Jan 2024 14:41:57 +0000 Subject: [PATCH 03/11] Added option to read model from model path to onnx wrapper --- .../johnsnowlabs/ml/onnx/OnnxWrapper.scala | 48 +++++++++---------- .../seq2seq/LLAMA2Transformer.scala | 19 -------- 2 files changed, 22 insertions(+), 45 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index ff30ee3954400b..9f36e61be637e1 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -81,7 +81,8 @@ object OnnxWrapper { // TODO: make sure this.synchronized is needed or it's not a bottleneck private def withSafeOnnxModelLoader( onnxModel: Array[Byte], - sessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) = + sessionOptions: Map[String, String], + onnxModelPath: Option[String] = None): (OrtSession, OrtEnvironment) = this.synchronized { val env = OrtEnvironment.getEnvironment() val sessionOptionsObject = if (sessionOptions.isEmpty) { @@ -89,24 +90,13 @@ object OnnxWrapper { } else { mapToSessionOptionsObject(sessionOptions) } - - val session = env.createSession(onnxModel, sessionOptionsObject) - (session, env) - } - - private def withSafeOnnxModelPathLoader( - onnxModelPath: String, - sessionOptions: Map[String, String]): (OrtSession, OrtEnvironment) = - this.synchronized { - val env = OrtEnvironment.getEnvironment() - val sessionOptionsObject = if (sessionOptions.isEmpty) { - new SessionOptions() + if (onnxModelPath.isDefined) { + val session = env.createSession(onnxModelPath.get, sessionOptionsObject) + (session, env) } else { - mapToSessionOptionsObject(sessionOptions) + val session = env.createSession(onnxModel, sessionOptionsObject) + (session, env) } - - val session = env.createSession(onnxModelPath, sessionOptionsObject) - (session, env) } def read( @@ -132,23 +122,29 @@ object OnnxWrapper { val onnxFile = if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString else Paths.get(folder, new File(folder).list().head).toString + + // see if the onnx model has a .onnx_data file + val onnxDataFile: Boolean = if (useBundle) { + val onnxDataFile = Paths.get(modelPath, s"$modelName.onnx_data").toFile + onnxDataFile.exists() + } else { + val onnxDataFile = Paths.get(folder, new File(folder).list().head + "_data").toFile + onnxDataFile.exists() + } val modelFile = new File(onnxFile) val modelBytes = 
FileUtils.readFileToByteArray(modelFile) var session: OrtSession = null var env: OrtEnvironment = null - try { + if (onnxDataFile) { + val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile)) + session = _session + env = _env + } else { val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) session = _session env = _env - } catch { - case e: Exception => { - println("Error loading model from file, trying to load from path") - val (_session, _env) = withSafeOnnxModelPathLoader(onnxFile, sessionOptions) - session = _session - env = _env - } - } + } // 4. Remove tmp folder FileHelper.delete(tmpFolder) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index 32802ea9979579..3de44c5b6f70a4 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -257,8 +257,6 @@ class LLAMA2Transformer(override val uid: String) .flatMap { case (annotations, i) => annotations.filter(_.result.nonEmpty).map(x => (x, i)) } - println(s"allAnnotations: ${allAnnotations.length}") - println(s"${allAnnotations.head._1}") val processedAnnotations = if (allAnnotations.nonEmpty) { this.getModelIfNotSet.predict( sentences = allAnnotations.map(_._1), @@ -278,23 +276,6 @@ class LLAMA2Transformer(override val uid: String) } else { Seq() } - - // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence -// batchedAnnotations.indices.map(rowIndex => { -// val rowAnnotations = processedAnnotations -// // zip each annotation with its corresponding row index -// .zip(allAnnotations) -// // select the sentences belonging to the current row -// .filter(_._2._2 == rowIndex) -// // leave the annotation only -// .map(_._1) -// -// if (rowAnnotations.nonEmpty) -// rowAnnotations -// else -// Seq.empty[Annotation] -// }) -// Seq(Seq.empty[Annotation]) Seq(processedAnnotations) } From b4cf4cf948aca185e870da0bbce270828ea7d8b8 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Wed, 17 Jan 2024 15:21:40 +0000 Subject: [PATCH 04/11] updated text description --- .../seq2seq/LLAMA2Transformer.scala | 67 ++++++++----------- 1 file changed, 28 insertions(+), 39 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index 3de44c5b6f70a4..859d73107ffbe2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -42,57 +42,46 @@ import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} import org.json4s._ import org.json4s.jackson.JsonMethods._ -/** GPT-2: the OpenAI Text-To-Text Transformer +/** Llama 2: Open Foundation and Fine-Tuned Chat Models * - * GPT-2 is a large transformer-based language model with 1.5 billion parameters, trained on a - * dataset of 8 million web pages. GPT-2 is trained with a simple objective: predict the next - * word, given all of the previous words within some text. The diversity of the dataset causes - * this simple goal to contain naturally occurring demonstrations of many tasks across diverse - * domains. 
GPT-2 is a direct scale-up of GPT, with more than 10X the parameters and trained on - * more than 10X the amount of data. + * The Llama 2 release introduces a family of pretrained and fine-tuned LLMs, ranging in scale + * from 7B to 70B parameters (7B, 13B, 70B). The pretrained models come with significant + * improvements over the Llama 1 models, including being trained on 40% more tokens, having a + * much longer context length (4k tokens 🤯), and using grouped-query attention for fast + * inference of the 70B model🔥! * - * GPT-2 displays a broad set of capabilities, including the ability to generate conditional - * synthetic text samples of unprecedented quality, where we prime the model with an input and - * have it generate a lengthy continuation. In addition, GPT-2 outperforms other language models - * trained on specific domains (like Wikipedia, news, or books) without needing to use these - * domain-specific training datasets. On language tasks like question answering, reading - * comprehension, summarization, and translation, GPT-2 begins to learn these tasks from the raw - * text, using no task-specific training data. While scores on these downstream tasks are far - * from state-of-the-art, they suggest that the tasks can benefit from unsupervised techniques, - * given sufficient (unlabeled) data and compute. + * However, the most exciting part of this release is the fine-tuned models (Llama 2-Chat), which + * have been optimized for dialogue applications using Reinforcement Learning from Human Feedback + * (RLHF). Across a wide range of helpfulness and safety benchmarks, the Llama 2-Chat models + * perform better than most open models and achieve comparable performance to ChatGPT according + * to human evaluations. * * Pretrained models can be loaded with `pretrained` of the companion object: * {{{ - * val gpt2 = LLAMA2Transformer.pretrained() + * val llama2 = LLAMA2Transformer.pretrained() * .setInputCols("document") * .setOutputCol("generation") * }}} - * The default model is `"gpt2"`, if no name is provided. For available pretrained models please - * see the [[https://sparknlp.org/models?q=gpt2 Models Hub]]. + * The default model is `"llama2-7b"`, if no name is provided. For available pretrained models + * please see the [[https://sparknlp.org/models?q=llama2 Models Hub]]. * * For extended examples of usage, see - * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2TestSpec.scala GPT2TestSpec]]. + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala LLAMA2TestSpec]]. * * '''References:''' - * - [[https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf Language Models are Unsupervised Multitask Learners]] - * - [[https://github.com/openai/gpt-2]] + * - [[https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/ Llama 2: Open Foundation and Fine-Tuned Chat Models]] + * - [[https://github.com/facebookresearch/llama]] * * '''Paper Abstract:''' * - * ''Natural language processing tasks, such as question answering, machine translation, reading - * comprehension, and summarization, are typically approached with supervised learning on - * taskspecific datasets. We demonstrate that language models begin to learn these tasks without - * any explicit supervision when trained on a new dataset of millions of webpages called WebText. 
- * When conditioned on a document plus questions, the answers generated by the language model - * reach F1 on the CoQA dataset - matching or exceeding the performance of 3 out of 4 baseline - * systems without using the 127,000+ training examples. The capacity of the language model is - * essential to the success of zero-shot task transfer and increasing it improves performance in - * a log-linear fashion across tasks. Our largest model, GPT-2, is a 1.5B parameter Transformer - * that achieves state of the art results on 7 out of 8 tested language modeling datasets in a - * zero-shot setting but still underfits WebText. Samples from the model reflect these - * improvements and contain coherent paragraphs of text. These findings suggest a promising path - * towards building language processing systems which learn to perform tasks from their naturally - * occurring demonstrations.'' + * ''In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned + * large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our + * fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models + * outperform open-source chat models on most benchmarks we tested, and based on our human + * evaluations for helpfulness and safety, may be a suitable substitute for closed-source models. + * We provide a detailed description of our approach to fine-tuning and safety improvements of + * Llama 2-Chat in order to enable the community to build on our work and contribute to the + * responsible development of LLMs.'' * * '''Note:''' * @@ -110,7 +99,7 @@ import org.json4s.jackson.JsonMethods._ * .setInputCol("text") * .setOutputCol("documents") * - * val gpt2 = LLAMA2Transformer.pretrained("gpt2") + * val llama2 = LLAMA2Transformer.pretrained("llama2-7b") * .setInputCols(Array("documents")) * .setMinOutputLength(10) * .setMaxOutputLength(50) @@ -119,7 +108,7 @@ import org.json4s.jackson.JsonMethods._ * .setNoRepeatNgramSize(3) * .setOutputCol("generation") * - * val pipeline = new Pipeline().setStages(Array(documentAssembler, gpt2)) + * val pipeline = new Pipeline().setStages(Array(documentAssembler, llama2)) * * val data = Seq( * "My name is Leonardo." 
@@ -296,7 +285,7 @@ class LLAMA2Transformer(override val uid: String) trait ReadablePretrainedLLAMA2TransformerModel extends ParamsAndFeaturesReadable[LLAMA2Transformer] with HasPretrained[LLAMA2Transformer] { - override val defaultModelName: Some[String] = Some("llama2") + override val defaultModelName: Some[String] = Some("llama2-7b") /** Java compliant-overrides */ override def pretrained(): LLAMA2Transformer = super.pretrained() From 44bbc87be29610888d36a6f7605bde2184a291f9 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 18 Jan 2024 13:37:06 +0000 Subject: [PATCH 05/11] LLAMA2 python API --- python/sparknlp/annotator/seq2seq/__init__.py | 1 + .../annotator/seq2seq/llama2_transformer.py | 343 ++++++++++++++++++ python/sparknlp/internal/__init__.py | 4 + .../seq2seq/llama2_transformer_test.py | 47 +++ .../seq2seq/LLAMA2Transformer.scala | 2 - .../annotators/seq2seq/LLAMA2TestSpec.scala | 2 +- 6 files changed, 396 insertions(+), 3 deletions(-) create mode 100644 python/sparknlp/annotator/seq2seq/llama2_transformer.py create mode 100644 python/test/annotator/seq2seq/llama2_transformer_test.py diff --git a/python/sparknlp/annotator/seq2seq/__init__.py b/python/sparknlp/annotator/seq2seq/__init__.py index f1bbfdac84535a..8bb8c6af6535e4 100644 --- a/python/sparknlp/annotator/seq2seq/__init__.py +++ b/python/sparknlp/annotator/seq2seq/__init__.py @@ -17,3 +17,4 @@ from sparknlp.annotator.seq2seq.marian_transformer import * from sparknlp.annotator.seq2seq.t5_transformer import * from sparknlp.annotator.seq2seq.bart_transformer import * +from sparknlp.annotator.seq2seq.llama2_transformer import * diff --git a/python/sparknlp/annotator/seq2seq/llama2_transformer.py b/python/sparknlp/annotator/seq2seq/llama2_transformer.py new file mode 100644 index 00000000000000..0960d53c09cc11 --- /dev/null +++ b/python/sparknlp/annotator/seq2seq/llama2_transformer.py @@ -0,0 +1,343 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for the LLAMA2Transformer.""" + +from sparknlp.common import * + + +class LLAMA2Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine): + """Llama 2: Open Foundation and Fine-Tuned Chat Models + + The Llama 2 release introduces a family of pretrained and fine-tuned LLMs, ranging in scale + from 7B to 70B parameters (7B, 13B, 70B). The pretrained models come with significant + improvements over the Llama 1 models, including being trained on 40% more tokens, having a + much longer context length (4k tokens 🤯), and using grouped-query attention for fast + inference of the 70B model🔥! + + However, the most exciting part of this release is the fine-tuned models (Llama 2-Chat), which + have been optimized for dialogue applications using Reinforcement Learning from Human Feedback + (RLHF). Across a wide range of helpfulness and safety benchmarks, the Llama 2-Chat models + perform better than most open models and achieve comparable performance to ChatGPT according + to human evaluations. 
+
+    Pretrained models can be loaded with :meth:`.pretrained` of the companion
+    object:
+
+    >>> llama2 = LLAMA2Transformer.pretrained() \\
+    ...     .setInputCols(["document"]) \\
+    ...     .setOutputCol("generation")
+
+
+    The default model is ``"llama2-7b"``, if no name is provided. For available
+    pretrained models please see the `Models Hub
+    <https://sparknlp.org/models?q=llama2>`__.
+
+    ====================== ======================
+    Input Annotation types Output Annotation type
+    ====================== ======================
+    ``DOCUMENT``           ``DOCUMENT``
+    ====================== ======================
+
+    Parameters
+    ----------
+    configProtoBytes
+        ConfigProto from tensorflow, serialized into byte array.
+    minOutputLength
+        Minimum length of the sequence to be generated, by default 0
+    maxOutputLength
+        Maximum length of output text, by default 20
+    doSample
+        Whether or not to use sampling; use greedy decoding otherwise, by default False
+    temperature
+        The value used to modulate the next token probabilities, by default 1.0
+    topK
+        The number of highest probability vocabulary tokens to keep for
+        top-k-filtering, by default 50
+    topP
+        Top cumulative probability for vocabulary tokens, by default 1.0
+
+        If set to float < 1, only the most probable tokens with probabilities
+        that add up to ``topP`` or higher are kept for generation.
+    repetitionPenalty
+        The parameter for repetition penalty, 1.0 means no penalty, by default
+        1.0
+    noRepeatNgramSize
+        If set to int > 0, all ngrams of that size can only occur once, by
+        default 0
+    ignoreTokenIds
+        A list of token ids which are ignored in the decoder's output, by
+        default []
+
+    Notes
+    -----
+    This is a very computationally expensive module especially on larger
+    sequences. The use of an accelerator such as GPU is recommended.
+
+    References
+    ----------
+    - `Llama 2: Open Foundation and Fine-Tuned Chat Models
+      <https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/>`__
+    - https://github.com/facebookresearch/llama
+
+    **Paper Abstract:**
+
+    *In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned
+    large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our
+    fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models
+    outperform open-source chat models on most benchmarks we tested, and based on our human
+    evaluations for helpfulness and safety, may be a suitable substitute for closed-source models.
+    We provide a detailed description of our approach to fine-tuning and safety improvements of
+    Llama 2-Chat in order to enable the community to build on our work and contribute to the
+    responsible development of LLMs.*
+
+    Examples
+    --------
+    >>> import sparknlp
+    >>> from sparknlp.base import *
+    >>> from sparknlp.annotator import *
+    >>> from pyspark.ml import Pipeline
+    >>> documentAssembler = DocumentAssembler() \\
+    ...     .setInputCol("text") \\
+    ...     .setOutputCol("documents")
+    >>> llama2 = LLAMA2Transformer.pretrained("llama2-7b") \\
+    ...     .setInputCols(["documents"]) \\
+    ...     .setMaxOutputLength(50) \\
+    ...     .setOutputCol("generation")
+    >>> pipeline = Pipeline().setStages([documentAssembler, llama2])
+    >>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text")
+    >>> result = pipeline.fit(data).transform(data)
+    >>> result.select("generation.result").show(truncate=False)
+    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+    |result                                                                                                                                                                                              |
+    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+    |[My name is Leonardo. I am a man of letters. I have been a man for many years. I was born in the year 1776. I came to the United States in 1776, and I have lived in the United Kingdom since 1776.]|
+    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+    """
+
+    name = "LLAMA2Transformer"
+
+    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]
+
+    outputAnnotatorType = AnnotatorType.DOCUMENT
+
+
+    configProtoBytes = Param(Params._dummy(),
+                             "configProtoBytes",
+                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
+                             TypeConverters.toListInt)
+
+    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
+                            typeConverter=TypeConverters.toInt)
+
+    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
+                            typeConverter=TypeConverters.toInt)
+
+    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
+                     typeConverter=TypeConverters.toBoolean)
+
+    temperature = Param(Params._dummy(), "temperature", "The value used to modulate the next token probabilities",
+                        typeConverter=TypeConverters.toFloat)
+
+    topK = Param(Params._dummy(), "topK",
+                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
+                 typeConverter=TypeConverters.toInt)
+
+    topP = Param(Params._dummy(), "topP",
+                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
+                 typeConverter=TypeConverters.toFloat)
+
+    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
+                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
+                              typeConverter=TypeConverters.toFloat)
+
+    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
+                              "If set to int > 0, all ngrams of that size can only occur once",
+                              typeConverter=TypeConverters.toInt)
+
+    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
+                           "A list of token ids which are ignored in the decoder's output",
+                           typeConverter=TypeConverters.toListInt)
+
+
+    def setIgnoreTokenIds(self, value):
+        """A list of token ids which are ignored in the decoder's output.
+
+        Parameters
+        ----------
+        value : List[int]
+            The words to be filtered out
+        """
+        return self._set(ignoreTokenIds=value)
+
+    def setConfigProtoBytes(self, b):
+        """Sets configProto from tensorflow, serialized into byte array.
+
+        Parameters
+        ----------
+        b : List[int]
+            ConfigProto from tensorflow, serialized into byte array
+        """
+        return self._set(configProtoBytes=b)
+
+    def setMinOutputLength(self, value):
+        """Sets minimum length of the sequence to be generated.
+
+        Parameters
+        ----------
+        value : int
+            Minimum length of the sequence to be generated
+        """
+        return self._set(minOutputLength=value)
+
+    def setMaxOutputLength(self, value):
+        """Sets maximum length of output text.
+
+        Parameters
+        ----------
+        value : int
+            Maximum length of output text
+        """
+        return self._set(maxOutputLength=value)
+
+    def setDoSample(self, value):
+        """Sets whether or not to use sampling, use greedy decoding otherwise.
+
+        Parameters
+        ----------
+        value : bool
+            Whether or not to use sampling; use greedy decoding otherwise
+        """
+        return self._set(doSample=value)
+
+    def setTemperature(self, value):
+        """Sets the value used to modulate the next token probabilities.
+
+        Parameters
+        ----------
+        value : float
+            The value used to modulate the next token probabilities
+        """
+        return self._set(temperature=value)
+
+    def setTopK(self, value):
+        """Sets the number of highest probability vocabulary tokens to keep for
+        top-k-filtering.
+
+        Parameters
+        ----------
+        value : int
+            Number of highest probability vocabulary tokens to keep
+        """
+        return self._set(topK=value)
+
+    def setTopP(self, value):
+        """Sets the top cumulative probability for vocabulary tokens.
+
+        If set to float < 1, only the most probable tokens with probabilities
+        that add up to ``topP`` or higher are kept for generation.
+
+        Parameters
+        ----------
+        value : float
+            Cumulative probability for vocabulary tokens
+        """
+        return self._set(topP=value)
+
+    def setRepetitionPenalty(self, value):
+        """Sets the parameter for repetition penalty. 1.0 means no penalty.
+
+        Parameters
+        ----------
+        value : float
+            The repetition penalty
+
+        References
+        ----------
+        See `Ctrl: A Conditional Transformer Language Model For Controllable
+        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
+        """
+        return self._set(repetitionPenalty=value)
+
+    def setNoRepeatNgramSize(self, value):
+        """Sets size of n-grams that can only occur once.
+
+        If set to int > 0, all ngrams of that size can only occur once.
+
+        Parameters
+        ----------
+        value : int
+            N-gram size can only occur once
+        """
+        return self._set(noRepeatNgramSize=value)
+
+    @keyword_only
+    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Transformer", java_model=None):
+        super(LLAMA2Transformer, self).__init__(
+            classname=classname,
+            java_model=java_model
+        )
+        self._setDefault(
+            minOutputLength=0,
+            maxOutputLength=20,
+            doSample=False,
+            temperature=1.0,
+            topK=50,
+            topP=1.0,
+            repetitionPenalty=1.0,
+            noRepeatNgramSize=0,
+            ignoreTokenIds=[],
+            batchSize=1
+        )
+
+    @staticmethod
+    def loadSavedModel(folder, spark_session):
+        """Loads a locally saved model.
+
+        Parameters
+        ----------
+        folder : str
+            Folder of the saved model
+        spark_session : pyspark.sql.SparkSession
+            The current SparkSession
+
+        Returns
+        -------
+        LLAMA2Transformer
+            The restored model
+        """
+        from sparknlp.internal import _LLAMA2Loader
+        jModel = _LLAMA2Loader(folder, spark_session._jsparkSession)._java_obj
+        return LLAMA2Transformer(java_model=jModel)
+
+    @staticmethod
+    def pretrained(name="llama2-7b", lang="en", remote_loc=None):
+        """Downloads and loads a pretrained model.
+ + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "llama2-7b" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLP's repositories otherwise. + + Returns + ------- + LLAMA2Transformer + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(LLAMA2Transformer, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index f49a5e4768deab..fc2c4dccd96df6 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -156,6 +156,10 @@ def __init__(self, path, jspark): super(_GPT2Loader, self).__init__( "com.johnsnowlabs.nlp.annotators.seq2seq.GPT2Transformer.loadSavedModel", path, jspark) +class _LLAMA2Loader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_LLAMA2Loader, self).__init__( + "com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Transformer.loadSavedModel", path, jspark) class _LongformerLoader(ExtendedJavaWrapper): def __init__(self, path, jspark): diff --git a/python/test/annotator/seq2seq/llama2_transformer_test.py b/python/test/annotator/seq2seq/llama2_transformer_test.py new file mode 100644 index 00000000000000..42b6ae3d2dcbaf --- /dev/null +++ b/python/test/annotator/seq2seq/llama2_transformer_test.py @@ -0,0 +1,47 @@ +# Copyright 2017-2022 John Snow Labs + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
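+# Smoke test for LLAMA2Transformer: assembles the raw "text" column into
+# documents, runs the pretrained LLAMA2 decoder with greedy decoding
+# (doSample=False, up to 50 output tokens), and prints the "generation" column.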
+import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +@pytest.mark.slow +class LLAMA2TransformerTextGenerationTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + + def runTest(self): + data = self.spark.createDataFrame([ + [1, """Leonardo Da Vinci invented the microscope?""".strip().replace("\n", " ")]]).toDF("id", "text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + llama2 = LLAMA2Transformer \ + .pretrained() \ + .setMaxOutputLength(50) \ + .setDoSample(False) \ + .setInputCols(["documents"]) \ + .setOutputCol("generation") + + pipeline = Pipeline().setStages([document_assembler, llama2]) + results = pipeline.fit(data).transform(data) + + results.select("generation.result").show(truncate=False) + diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index 859d73107ffbe2..e5dac2fbdecae7 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -323,8 +323,6 @@ trait ReadLLAMA2TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM def loadSavedModel(modelPath: String, spark: SparkSession): LLAMA2Transformer = { implicit val formats: DefaultFormats.type = DefaultFormats // for json4s - // print model path - println(s"$modelPath") val (localModelPath, detectedEngine) = modelSanityCheck(modelPath, isDecoder = true) val modelConfig: JValue = diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala index c908fa54e909b2..e890569a2ae15f 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -25,7 +25,7 @@ import org.scalatest.flatspec.AnyFlatSpec class LLAMA2TestSpec extends AnyFlatSpec { - "bart-large-cnn" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs FastTest in { + "bart-large-cnn" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs SlowTest in { // Even though the Paper states temperature in interval [0,1), using temperature=0 will result in division by 0 error. // Also DoSample=True may result in infinities being generated and distFiltered.length==0 which results in exception if we don't return 0 instead internally. 
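// A sketch of the guards that behavior implies (illustrative only; the helper names
// below are hypothetical, not the actual implementation):
//   val safeTemperature = if (temperature <= 0.0) 1.0 else temperature // avoid dividing logits by zero at t=0
//   val distFiltered = dist.filter(d => !d.isInfinite && !d.isNaN) // doSample=True can produce infinities
//   val nextToken = if (distFiltered.isEmpty) 0 else sampleFrom(distFiltered) // return 0 instead of throwing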
val testData = ResourceHelper.spark From 95d2587f778646b8468b6fdf857be5a95beb072d Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 18 Jan 2024 16:08:46 +0000 Subject: [PATCH 06/11] added method to save onnx_data --- .../ml/onnx/OnnxSerializeModel.scala | 7 ++++++ .../johnsnowlabs/ml/onnx/OnnxWrapper.scala | 23 +++++++++++-------- .../annotators/seq2seq/LLAMA2TestSpec.scala | 6 +++-- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index c9e2f2890ee72f..ac8a2568d99826 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -52,6 +52,13 @@ trait WriteOnnxModel { // 3. Copy to dest folder fs.copyFromLocalFile(new Path(onnxFile), new Path(path)) + + // Also copy the external onnx_data file to the dest folder, if present + + val onnxDataFile = Paths.get(onnxWrapper.onnxModelPath.get + "_data").toFile + if (onnxDataFile.exists()) { + fs.copyFromLocalFile(new Path(onnxDataFile.getAbsolutePath), new Path(path)) + } } // 4. Remove tmp folder diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index 9f36e61be637e1..fa0b4909589a92 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -23,17 +23,18 @@ import ai.onnxruntime.{OrtEnvironment, OrtSession} import com.johnsnowlabs.util.{ConfigHelper, FileHelper, ZipArchiveUtil} import org.apache.commons.io.FileUtils import org.slf4j.{Logger, LoggerFactory} - +import org.apache.hadoop.fs.{FileSystem, Path} import java.io._ import java.nio.file.{Files, Paths} import java.util.UUID import scala.util.{Failure, Success, Try} -class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable { +class OnnxWrapper(var onnxModel: Array[Byte], var onnxModelPath: Option[String] = None) + extends Serializable { /** For Deserialization */ def this() = { - this(null) + this(null, None) } // Important for serialization on non-Kryo serializers @@ -44,7 +45,8 @@ class OnnxWrapper(var onnxModel: Array[Byte]) extends Serializable { this.synchronized { // TODO: After testing it works remove the Map.empty if (ortSession == null && ortEnv == null) { - val (session, env) = OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions) + val (session, env) = + OnnxWrapper.withSafeOnnxModelLoader(onnxModel, onnxSessionOptions, onnxModelPath) ortEnv = env ortSession = session } @@ -123,24 +125,25 @@ object OnnxWrapper { if (useBundle) Paths.get(modelPath, s"$modelName.onnx").toString else Paths.get(folder, new File(folder).list().head).toString + var onnxDataFile: File = null // see if the onnx model has a .onnx_data file - val onnxDataFile: Boolean = if (useBundle) { - val onnxDataFile = Paths.get(modelPath, s"$modelName.onnx_data").toFile + val onnxDataFileExist: Boolean = if (useBundle) { + onnxDataFile = Paths.get(modelPath, s"$modelName.onnx_data").toFile onnxDataFile.exists() } else { - val onnxDataFile = Paths.get(folder, new File(folder).list().head + "_data").toFile + onnxDataFile = Paths.get(folder, new File(folder).list().head + "_data").toFile onnxDataFile.exists() } val modelFile = new File(onnxFile) val modelBytes = FileUtils.readFileToByteArray(modelFile) var session: OrtSession = null var env: OrtEnvironment = null - if (onnxDataFile) { + if (onnxDataFileExist) { val (_session, 
_env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile)) session = _session env = _env } else { - val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions) + val (_session, _env) = withSafeOnnxModelLoader(modelBytes, sessionOptions, Some(onnxFile)) session = _session env = _env @@ -148,7 +151,7 @@ object OnnxWrapper { // 4. Remove tmp folder FileHelper.delete(tmpFolder) - val onnxWrapper = new OnnxWrapper(modelBytes) + val onnxWrapper = new OnnxWrapper(modelBytes, Option(onnxFile)) onnxWrapper.ortSession = session onnxWrapper.ortEnv = env onnxWrapper diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala index e890569a2ae15f..9de31b070a5b3e 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -25,7 +25,7 @@ import org.scalatest.flatspec.AnyFlatSpec class LLAMA2TestSpec extends AnyFlatSpec { - "bart-large-cnn" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs SlowTest in { + "llama-7b" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs FastTest in { // Even though the Paper states temperature in interval [0,1), using temperature=0 will result in division by 0 error. // Also DoSample=True may result in infinities being generated and distFiltered.length==0 which results in exception if we don't return 0 instead internally. val testData = ResourceHelper.spark @@ -38,7 +38,9 @@ class LLAMA2TestSpec extends AnyFlatSpec { .setOutputCol("documents") val bart = LLAMA2Transformer - .loadSavedModel("/home/prabod/Projects/ModelZoo/BART/BART/llama2_7b/", ResourceHelper.spark) + .loadSavedModel( + "/home/prabod/Projects/ModelZoo/BART/BART/llama2_7b/onnx/", + ResourceHelper.spark) .setInputCols(Array("documents")) .setDoSample(false) .setMaxOutputLength(50) From a9b2b5c19fc69cb7ba62d54abf986aad8914b55e Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Fri, 19 Jan 2024 11:18:23 +0000 Subject: [PATCH 07/11] added position ids --- .../scala/com/johnsnowlabs/ml/ai/LLAMA2.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala index 1f22dc28c1b0f5..a22f0ee16f6677 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala @@ -203,13 +203,24 @@ private[johnsnowlabs] class LLAMA2( val inputIdsLong: Array[Array[Long]] = inputIds.map { tokenIds => tokenIds.map(_.toLong) } + val inputPositionIDsLong: Array[Array[Long]] = + inputIds.map { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + val inputIdsLongTensor: OnnxTensor = OnnxTensor.createTensor(env, inputIdsLong) val decoderAttentionMask: OnnxTensor = OnnxTensor.createTensor(env, inputIdsLong.map(_.map(_ => 1L))) + val decoderPositionIDs: OnnxTensor = + OnnxTensor.createTensor(env, inputPositionIDsLong) + val decoderInputs: java.util.Map[String, OnnxTensor] = Map( OnnxSignatures.decoderInputIDs -> inputIdsLongTensor, - OnnxSignatures.decoderAttentionMask -> decoderAttentionMask).asJava + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask, + OnnxSignatures.decoderPositionIDs -> decoderPositionIDs).asJava val 
sessionOutput = session.run(decoderInputs) val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) inputIdsLongTensor.close() @@ -263,6 +274,8 @@ private object OnnxSignatures { val decoderInputIDs: String = "input_ids" val decoderAttentionMask: String = "attention_mask" + val decoderPositionIDs: String = "position_ids" + // create decoder past for 32 layers of key and value, e.g. past_key_values.0.key and past_key_values.0.value val decoderPast: Array[String] = (0 until 32) .flatMap(i => Seq(s"past_key_values.$i.key", s"past_key_values.$i.value")) From 5304c1fffd63f9c8940981687feec89cec1ce6f6 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Mon, 22 Jan 2024 14:16:39 +0000 Subject: [PATCH 08/11] - updated Generate.scala to accept onnx tensors - added beam search support for LLAMA2 --- .../scala/com/johnsnowlabs/ml/ai/Bart.scala | 34 ++++-- .../scala/com/johnsnowlabs/ml/ai/LLAMA2.scala | 108 ++++++++++++++---- .../ml/ai/VisionEncoderDecoder.scala | 18 +-- .../ml/ai/util/Generation/Generate.scala | 19 +-- .../seq2seq/LLAMA2Transformer.scala | 4 +- .../annotators/seq2seq/LLAMA2TestSpec.scala | 5 +- 6 files changed, 136 insertions(+), 52 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala index 87934db7686034..61970ed2f92a3f 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bart.scala @@ -16,6 +16,7 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} import com.johnsnowlabs.ml.ai.util.Generation.Generate import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -277,8 +278,8 @@ private[johnsnowlabs] class Bart( val decoderInputs = batch.map(_ => Array(this.eosTokenId)).toArray val modelOutputs = generate( batch, - decoderEncoderStateTensors, - encoderAttentionMaskTensors, + Left(decoderEncoderStateTensors), + Left(encoderAttentionMaskTensors), decoderInputs, maxOutputLength, minOutputLength, @@ -295,7 +296,7 @@ private[johnsnowlabs] class Bart( this.paddingTokenId, randomSeed, ignoreTokenIdsInt, - session) + Left(session)) tensorEncoder.clearTensors() tensorEncoder.clearSession(encoderOuts) @@ -362,10 +363,19 @@ private[johnsnowlabs] class Bart( override def getModelOutput( encoderInputIds: Seq[Array[Int]], decoderInputIds: Seq[Array[Int]], - decoderEncoderStateTensors: Tensor, - encoderAttentionMaskTensors: Tensor, + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], maxLength: Int, - session: Session): Array[Array[Float]] = { + session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = { + + // extract decoderEncoderStateTensors, encoderAttentionMaskTensors and Session from LEFT + assert(decoderEncoderStateTensors.isLeft) + assert(encoderAttentionMaskTensors.isLeft) + assert(session.isLeft) + + val decoderEncoderStateTensor: Tensor = decoderEncoderStateTensors.left.get + val encoderAttentionMaskTensor: Tensor = encoderAttentionMaskTensors.left.get + val sess: Session = session.left.get val sequencesLength = encoderInputIds.map(x => x.length).toArray var maxSentenceLength = sequencesLength.max // - curLen @@ -394,7 +404,7 @@ private[johnsnowlabs] class Bart( decoderInputBuffers) val runner = if (nextStateTensor1.isEmpty || nextStateTensor2.isEmpty) { - val r = 
session.runner + val r = sess.runner .feed( _tfBartSignatures.getOrElse( ModelSignatureConstants.InitDecoderInputIds.key, @@ -404,12 +414,12 @@ private[johnsnowlabs] class Bart( _tfBartSignatures.getOrElse( ModelSignatureConstants.InitDecoderEncoderInputIds.key, "missing_encoder_state_init"), - decoderEncoderStateTensors) + decoderEncoderStateTensor) .feed( _tfBartSignatures.getOrElse( ModelSignatureConstants.InitDecoderEncoderAttentionMask.key, "missing_decoder_encoder_attention_mask_init"), - encoderAttentionMaskTensors) + encoderAttentionMaskTensor) .fetch(_tfBartSignatures .getOrElse(ModelSignatureConstants.InitLogitsOutput.key, "missing_logits_init")) @@ -422,7 +432,7 @@ private[johnsnowlabs] class Bart( .fetch(_tfBartSignatures .getOrElse(ModelSignatureConstants.InitCachedOutPut2.key, "missing_cache2_out_init")) } else { - session.runner + sess.runner .feed( _tfBartSignatures.getOrElse( ModelSignatureConstants.CachedDecoderInputIds.key, @@ -432,12 +442,12 @@ private[johnsnowlabs] class Bart( _tfBartSignatures.getOrElse( ModelSignatureConstants.CachedDecoderEncoderInputIds.key, "missing_encoder_state"), - decoderEncoderStateTensors) + decoderEncoderStateTensor) .feed( _tfBartSignatures.getOrElse( ModelSignatureConstants.CachedDecoderEncoderAttentionMask.key, "missing_decoder_encoder_attention_mask"), - encoderAttentionMaskTensors) + encoderAttentionMaskTensor) .feed( _tfBartSignatures.getOrElse( ModelSignatureConstants.CachedDecoderInputCache1.key, diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala index a22f0ee16f6677..e0dcd2461b0a42 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.ml.ai import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} -import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig} import com.johnsnowlabs.ml.onnx.OnnxSession import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ @@ -26,22 +26,27 @@ import com.johnsnowlabs.nlp.Annotation import scala.collection.JavaConverters._ import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import org.tensorflow.{Session, Tensor} private[johnsnowlabs] class LLAMA2( val onnxWrappers: DecoderWrappers, val spp: SentencePieceWrapper, generationConfig: GenerationConfig) - extends Serializable { + extends Serializable + with Generate { + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + private val GenerationConfig( bosTokenId: Int, - _, + paddingTokenId: Int, eosTokenId: Int, vocabSize: Int, beginSuppressTokens, suppressTokenIds, forcedDecoderIds) = generationConfig + private val pieceSize = spp.getSppModel.getPieceSize /** Decode a sequence of sentences @@ -60,8 +65,6 @@ private[johnsnowlabs] class LLAMA2( /** Encode a sequence of sentences * @param sentences * Sequence of sentences - * @param task - * Task * @return * Sequence of encoded sentences */ @@ -88,9 +91,8 @@ private[johnsnowlabs] class LLAMA2( maxInputLength: Int): Array[Array[Int]] = { val (encoderSession, env) = onnxWrappers.decoder.getSession(onnxSessionOptions) val ignoreTokenIdsInt = ignoreTokenIds - val expandedEncoderInputIdsVals = - batch.flatMap(x => List.fill(beamSize)(x.take(maxInputLength))) - val sequencesLength = expandedEncoderInputIdsVals.map(x => x.length).toArray + val 
expandedDecoderInputsVals = batch + val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray val maxSentenceLength = sequencesLength.max // - curLen @@ -98,21 +100,52 @@ private[johnsnowlabs] class LLAMA2( var effectiveBatch_size = 1 var effectiveBatch_mult = 1 + if (doSample) { - effectiveBatch_size = expandedEncoderInputIdsVals.length * numReturn_sequences + effectiveBatch_size = expandedDecoderInputsVals.length * numReturn_sequences effectiveBatch_mult = numReturn_sequences } else { - effectiveBatch_size = expandedEncoderInputIdsVals.length + effectiveBatch_size = expandedDecoderInputsVals.length effectiveBatch_mult = 1 } // Run the prompt through the decoder and get the past - val decoderOutputs = - generateGreedyOnnx( - expandedEncoderInputIdsVals.toArray, - (encoderSession, env), - maxOutputLength) - decoderOutputs + + // dummy tensors for decoder encoder state and attention mask + val decoderEncoderStateTensors = Right(OnnxTensor.createTensor(env, Array(0))) + val encoderAttentionMaskTensors = Right(OnnxTensor.createTensor(env, Array(1))) + + // output with beam search + val modelOutputs = generate( + batch, + decoderEncoderStateTensors, + encoderAttentionMaskTensors, + expandedDecoderInputsVals.toArray, + maxOutputLength + maxSentenceLength, + minOutputLength, + doSample, + beamSize, + 1, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + this.vocabSize, + this.eosTokenId, + this.paddingTokenId, + randomSeed, + ignoreTokenIdsInt, + Right((env, encoderSession)), + applySoftmax = false) + + modelOutputs } def predict( @@ -195,6 +228,27 @@ } + override def getModelOutput( + encoderInputIds: Seq[Array[Int]], + decoderInputIds: Seq[Array[Int]], + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], + maxLength: Int, + session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = { + + session.fold( + tfSession => { + // not implemented yet + Array() + }, + onnxSession => { + val (env, decoderSession) = onnxSession + val decoderOutputs = + getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env)) + decoderOutputs + }) + + } private def getDecoderOutputs( inputIds: Array[Array[Int]], onnxSession: (OrtSession, OrtEnvironment)): (Array[Array[Float]]) = { val (session, env) = onnxSession @@ -222,11 +276,25 @@ private[johnsnowlabs] class LLAMA2( OnnxSignatures.decoderAttentionMask -> decoderAttentionMask, OnnxSignatures.decoderPositionIDs -> decoderPositionIDs).asJava val sessionOutput = session.run(decoderInputs) - val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) - inputIdsLongTensor.close() - val batchLogits = logits.grouped(vocabSize).toArray - batchLogits + val sequenceLength = inputIds.head.length + val batchSize = inputIds.length + + val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + // keep only the last-token logits of each sequence in the batch + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * sequenceLength * vocabSize + sequenceLength * 
vocabSize) + }) + decoderOutputs.toArray } /** Gets the index with the highest score diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/VisionEncoderDecoder.scala b/src/main/scala/com/johnsnowlabs/ml/ai/VisionEncoderDecoder.scala index bc4b1fde5cedf5..37b3de3c33ef94 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/VisionEncoderDecoder.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/VisionEncoderDecoder.scala @@ -16,6 +16,7 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} @@ -181,8 +182,8 @@ private[johnsnowlabs] class VisionEncoderDecoder( generate( inputIds = encoderIds, - decoderEncoderStateTensors = decoderEncoderStateTensors, - encoderAttentionMaskTensors = encoderAttentionMaskTensors, + decoderEncoderStateTensors = Left(decoderEncoderStateTensors), + encoderAttentionMaskTensors = Left(encoderAttentionMaskTensors), decoderInputs = decoderInputIds, maxOutputLength, minOutputLength, @@ -199,7 +200,7 @@ private[johnsnowlabs] class VisionEncoderDecoder( generationConfig.padId, randomSeed, Array.empty, - session) + Left(session)) } def generateFromImage( @@ -292,11 +293,14 @@ private[johnsnowlabs] class VisionEncoderDecoder( override def getModelOutput( encoderInputIds: Seq[Array[Int]], decoderInputIds: Seq[Array[Int]], - decoderEncoderStateTensors: Tensor, - encoderAttentionMaskTensors: Tensor, + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], maxLength: Int, - session: Session): Array[Array[Float]] = - getModelOutput(decoderInputIds, decoderEncoderStateTensors, session) + session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = { + val sess: Session = session.left.get + val decoderEncoderStateTensor: Tensor = decoderEncoderStateTensors.left.get + getModelOutput(decoderInputIds, decoderEncoderStateTensor, sess) + } def getModelOutput( decoderInputIds: Seq[Array[Int]], diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala index 3560a6859967ce..ee96819081fd3d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala @@ -16,6 +16,7 @@ package com.johnsnowlabs.ml.ai.util.Generation +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} import com.johnsnowlabs.ml.ai.util.Generation.Logit.LogitProcess.{ MinLengthLogitProcessor, NoRepeatNgramsLogitProcessor, @@ -82,8 +83,8 @@ trait Generate { */ def generate( inputIds: Seq[Array[Int]], - decoderEncoderStateTensors: Tensor, - encoderAttentionMaskTensors: Tensor, + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], decoderInputs: Array[Array[Int]], maxOutputLength: Int, minOutputLength: Int, @@ -100,7 +101,7 @@ trait Generate { paddingTokenId: Int, randomSeed: Option[Long], ignoreTokenIds: Array[Int] = Array(), - session: Session, + session: Either[Session, (OrtEnvironment, OrtSession)], applySoftmax: Boolean = true): Array[Array[Int]] = { // TODO: Add support for ignoreTokenIds @@ -178,8 +179,8 @@ trait Generate { def beamSearch( encoderInputIdsVals: Seq[Array[Int]], inputIdsVal: Seq[Array[Int]], - 
decoderEncoderStateTensors: Tensor, - encoderAttentionMaskTensors: Tensor, + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], beamScorer: BeamScorer, logitProcessor: LogitProcessorList, maxLength: Int, @@ -187,7 +188,7 @@ trait Generate { eosTokenId: Int, doSample: Boolean, randomSeed: Option[Long], - session: Session, + session: Either[Session, (OrtEnvironment, OrtSession)], applySoftmax: Boolean): Array[Array[Int]] = { val inputIds = inputIdsVal val batchSize = beamScorer.getBeamHypothesesSeq.length @@ -434,10 +435,10 @@ trait Generate { def getModelOutput( encoderInputIds: Seq[Array[Int]], decoderInputIds: Seq[Array[Int]], - decoderEncoderStateTensors: Tensor, - encoderAttentionMaskTensors: Tensor, + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], maxLength: Int, - session: Session): Array[Array[Float]] + session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] /** Samples from a multinomial distribution using the provided logits. * diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index e5dac2fbdecae7..474d9a1b8a12b9 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -219,9 +219,9 @@ class LLAMA2Transformer(override val uid: String) minOutputLength -> 0, maxOutputLength -> 20, doSample -> false, - temperature -> 1.0, + temperature -> 0.6, topK -> 50, - topP -> 1.0, + topP -> 0.9, repetitionPenalty -> 1.0, noRepeatNgramSize -> 3, ignoreTokenIds -> Array(), diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala index 9de31b070a5b3e..2a87dc4352b86a 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -39,12 +39,13 @@ class LLAMA2TestSpec extends AnyFlatSpec { val bart = LLAMA2Transformer .loadSavedModel( - "/home/prabod/Projects/ModelZoo/BART/BART/llama2_7b/onnx/", + "/home/prabod/Projects/ModelZoo/LLAMA2/llama2-7b-int4-cpu-no-merged/", ResourceHelper.spark) .setInputCols(Array("documents")) - .setDoSample(false) + .setDoSample(true) .setMaxOutputLength(50) .setOutputCol("generation") + .setBeamSize(2) new Pipeline() .setStages(Array(documentAssembler, bart)) .fit(testData) From 46e8f15df890d5358c786bbf3679cb512216e13d Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Mon, 22 Jan 2024 14:18:21 +0000 Subject: [PATCH 09/11] updated max input length --- .../johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index 474d9a1b8a12b9..f70e06d437cc0d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -227,7 +227,7 @@ class LLAMA2Transformer(override val uid: String) ignoreTokenIds -> Array(), batchSize -> 1, beamSize -> 1, - maxInputLength -> 512) + maxInputLength -> 4096) /** takes a document and annotations 
and produces new annotations of this annotator's annotation * type From c0b2c4f1904489eb06031efb7fe4ba55f95af05f Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Mon, 22 Jan 2024 15:39:39 +0000 Subject: [PATCH 10/11] updated python default params changed test to slow test --- python/sparknlp/annotator/seq2seq/llama2_transformer.py | 4 ++-- .../johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sparknlp/annotator/seq2seq/llama2_transformer.py b/python/sparknlp/annotator/seq2seq/llama2_transformer.py index 0960d53c09cc11..c5c80fbf00692e 100644 --- a/python/sparknlp/annotator/seq2seq/llama2_transformer.py +++ b/python/sparknlp/annotator/seq2seq/llama2_transformer.py @@ -291,9 +291,9 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Tran minOutputLength=0, maxOutputLength=20, doSample=False, - temperature=1.0, + temperature=0.6, topK=50, - topP=1.0, + topP=0.9, repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala index 2a87dc4352b86a..8fdef329ad1f53 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -25,7 +25,7 @@ import org.scalatest.flatspec.AnyFlatSpec class LLAMA2TestSpec extends AnyFlatSpec { - "llama-7b" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs FastTest in { + "llama-7b" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs SlowTest in { // Even though the Paper states temperature in interval [0,1), using temperature=0 will result in division by 0 error. // Also DoSample=True may result in infinities being generated and distFiltered.length==0 which results in exception if we don't return 0 instead internally. val testData = ResourceHelper.spark From 4dbd0d4d0c9c80028c6b1df8a690775d31d3fe8d Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Wed, 31 Jan 2024 15:17:17 +0000 Subject: [PATCH 11/11] fixed serialization bug --- .../ml/onnx/OnnxSerializeModel.scala | 14 ++++++++++++-- .../com/johnsnowlabs/ml/onnx/OnnxWrapper.scala | 18 +++++++++++++----- .../annotators/seq2seq/LLAMA2Transformer.scala | 13 +++++++++++-- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index ac8a2568d99826..b482ed733b54a0 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -134,8 +134,18 @@ trait ReadOnnxModel { val localPath = new Path(tmpFolder, localModelFile).toString - // 3. Read ONNX state - val onnxWrapper = OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle) + + val fsPath = new Path(path, localModelFile).toString + + // 3. Copy the onnx_data file if it exists + val onnxDataFile = Paths.get(fsPath + "_data").toFile + + if (onnxDataFile.exists()) { + fs.copyToLocalFile(new Path(path, localModelFile + "_data"), new Path(tmpFolder)) + } + + // 4. 
Read ONNX state + val onnxWrapper = + OnnxWrapper.read(localPath, zipped = zipped, useBundle = useBundle, modelName = modelName) (modelName, onnxWrapper) }).toMap diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala index fa0b4909589a92..7ea50744f5be9f 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxWrapper.scala @@ -126,14 +126,22 @@ object OnnxWrapper { else Paths.get(folder, new File(folder).list().head).toString var onnxDataFile: File = null + // see if the onnx model has a .onnx_data file - val onnxDataFileExist: Boolean = if (useBundle) { - onnxDataFile = Paths.get(modelPath, s"$modelName.onnx_data").toFile - onnxDataFile.exists() - } else { - onnxDataFile = Paths.get(folder, new File(folder).list().head + "_data").toFile + // get parent directory of onnx file if modelPath is a file + val parentDir = if (zipped) Paths.get(modelPath).getParent.toString else modelPath + + val onnxDataFileExist: Boolean = { + onnxDataFile = Paths.get(parentDir, s"${modelName.replace(".onnx", "")}.onnx_data").toFile onnxDataFile.exists() } + + if (onnxDataFileExist) { + val onnxDataFileTmp = + Paths.get(tmpFolder, s"${modelName.replace(".onnx", "")}.onnx_data").toFile + FileUtils.copyFile(onnxDataFile, onnxDataFileTmp) + } + val modelFile = new File(onnxFile) val modelBytes = FileUtils.readFileToByteArray(modelFile) var session: OrtSession = null diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala index f70e06d437cc0d..3193c6b3c5e57d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2Transformer.scala @@ -148,6 +148,7 @@ class LLAMA2Transformer(override val uid: String) with ParamsAndFeaturesWritable with WriteOnnxModel with HasGeneratorProperties + with WriteSentencePieceModel with HasEngine { def this() = this(Identifiable.randomUID("LLAMA2TRANSFORMER")) @@ -278,6 +279,13 @@ class LLAMA2Transformer(override val uid: String) spark, Seq((wrappers.decoder, "decoder_model.onnx")), LLAMA2Transformer.suffix) + val obj = getModelIfNotSet + writeSentencePieceModel( + path, + spark, + obj.spp, + LLAMA2Transformer.suffix, + LLAMA2Transformer.sppFile) } } } @@ -309,9 +317,10 @@ trait ReadLLAMA2TransformerDLModel extends ReadOnnxModel with ReadSentencePieceM def readModel(instance: LLAMA2Transformer, path: String, spark: SparkSession): Unit = { instance.getEngine match { case ONNX.name => - val wrappers = readOnnxModels(path, spark, Seq("decoder_model"), suffix) + val wrappers = + readOnnxModels(path, spark, Seq("decoder_model.onnx"), suffix) val onnxWrappers = - DecoderWrappers(decoder = wrappers("decoder_model")) + DecoderWrappers(decoder = wrappers("decoder_model.onnx")) val spp = readSentencePieceModel(path, spark, "_llama2_spp", sppFile) instance.setModelIfNotSet(spark, onnxWrappers, spp) case _ =>
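// Illustrative round trip for the serialization fix above (hypothetical local paths,
// standard Spark ML save/load calls; a sketch, not part of this patch):
//   val llama2 = LLAMA2Transformer.loadSavedModel("/tmp/llama2_onnx_export", spark) // stages decoder_model.onnx and, if present, its .onnx_data
//   llama2.write.overwrite().save("/tmp/llama2_spark_nlp") // WriteOnnxModel now copies the onnx_data file next to the model
//   val reloaded = LLAMA2Transformer.load("/tmp/llama2_spark_nlp") // ReadOnnxModel copies the *_data file back before deserializing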