From 145806da19d56a60b9d108bcc7cd46eca983ae27 Mon Sep 17 00:00:00 2001 From: Rajat Date: Thu, 30 May 2024 09:59:59 -0400 Subject: [PATCH] Fix incorrect LLAMA2 position ID --- .../scala/com/johnsnowlabs/ml/ai/LLAMA2.scala | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala index 11e20eca5b7fee..13968ce48cab3a 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala @@ -46,7 +46,6 @@ private[johnsnowlabs] class LLAMA2( else if (openvinoWrapper.isDefined) Openvino.name else ONNX.name - private var nextPositionId: Option[Array[Long]] = None private val GenerationConfig( bosTokenId: Int, paddingTokenId: Int, @@ -168,7 +167,6 @@ private[johnsnowlabs] class LLAMA2( applySoftmax = false, ovInferRequest = ovInferRequest) - nextPositionId = None modelOutputs } @@ -272,38 +270,49 @@ private[johnsnowlabs] class LLAMA2( decoderOutputs case Openvino.name => val decoderOutputs = - getDecoderOutputsOv(decoderInputIds.toArray, ovInferRequest.get) + getDecoderOutputsOv( + encoderInputIds.toArray, + decoderInputIds.toArray, + ovInferRequest.get) decoderOutputs } } private def getDecoderOutputsOv( - inputIds: Array[Array[Int]], + encoderInputIds: Array[Array[Int]], + decoderInputIds: Array[Array[Int]], inferRequest: InferRequest): (Array[Array[Float]]) = { val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) = - if (nextPositionId.isDefined) { - val inpIdsLong = inputIds.map { tokenIds => tokenIds.last.toLong } - (inpIdsLong, nextPositionId.get) - } else { - val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } - val posIdsLong = inputIds.flatMap { tokenIds => + if (encoderInputIds.head.length == decoderInputIds.head.length) { + // First pass + val inpIdsLong = decoderInputIds.flatMap { tokenIds => tokenIds.map(_.toLong) } + val posIdsLong = decoderInputIds.flatMap { tokenIds => tokenIds.zipWithIndex.map { case (_, i) => i.toLong } } (inpIdsLong, posIdsLong) + } else { + // Subsequent passes + val inpIdsLong = decoderInputIds.map { tokenIds => tokenIds.last.toLong } + val posIdsLong = decoderInputIds.map { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + }.last + } + (inpIdsLong, posIdsLong) } val attentionMask: Array[Long] = - inputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) } + decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) } - val batchSize: Int = inputIds.length + val batchSize: Int = decoderInputIds.length val beamIdx: Array[Int] = new Array[Int](batchSize) val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize) val inputIdsLongTensor: org.intel.openvino.Tensor = new org.intel.openvino.Tensor(shape, inputIdsLong) val decoderAttentionMask: org.intel.openvino.Tensor = - new org.intel.openvino.Tensor(Array(batchSize, inputIds.head.length), attentionMask) + new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask) val decoderPositionIDs: org.intel.openvino.Tensor = new org.intel.openvino.Tensor(shape, inputPositionIDsLong) val beamIdxTensor: org.intel.openvino.Tensor = @@ -318,7 +327,6 @@ private[johnsnowlabs] class LLAMA2( val result = inferRequest.get_tensor("logits") val logitsRaw = result.data() - nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong)) val sequenceLength = inputIdsLong.length / batchSize val decoderOutputs = (0 until batchSize).map(i => {