Skip to content

Commit

Permalink
Fix incorrect LLAMA2 position ID (#14308)
Browse files Browse the repository at this point in the history
  • Loading branch information
rajatkrishna authored Jun 3, 2024
1 parent c2048be commit adc193e
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions src/main/scala/com/johnsnowlabs/ml/ai/LLAMA2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ private[johnsnowlabs] class LLAMA2(
else if (openvinoWrapper.isDefined) Openvino.name
else ONNX.name

private var nextPositionId: Option[Array[Long]] = None
private val GenerationConfig(
bosTokenId: Int,
paddingTokenId: Int,
Expand Down Expand Up @@ -168,7 +167,6 @@ private[johnsnowlabs] class LLAMA2(
applySoftmax = false,
ovInferRequest = ovInferRequest)

nextPositionId = None
modelOutputs
}

Expand Down Expand Up @@ -272,38 +270,49 @@ private[johnsnowlabs] class LLAMA2(
decoderOutputs
case Openvino.name =>
val decoderOutputs =
getDecoderOutputsOv(decoderInputIds.toArray, ovInferRequest.get)
getDecoderOutputsOv(
encoderInputIds.toArray,
decoderInputIds.toArray,
ovInferRequest.get)
decoderOutputs
}
}

private def getDecoderOutputsOv(
inputIds: Array[Array[Int]],
encoderInputIds: Array[Array[Int]],
decoderInputIds: Array[Array[Int]],
inferRequest: InferRequest): (Array[Array[Float]]) = {
val (inputIdsLong, inputPositionIDsLong): (Array[Long], Array[Long]) =
if (nextPositionId.isDefined) {
val inpIdsLong = inputIds.map { tokenIds => tokenIds.last.toLong }
(inpIdsLong, nextPositionId.get)
} else {
val inpIdsLong = inputIds.flatMap { tokenIds => tokenIds.map(_.toLong) }
val posIdsLong = inputIds.flatMap { tokenIds =>
if (encoderInputIds.head.length == decoderInputIds.head.length) {
// First pass
val inpIdsLong = decoderInputIds.flatMap { tokenIds => tokenIds.map(_.toLong) }
val posIdsLong = decoderInputIds.flatMap { tokenIds =>
tokenIds.zipWithIndex.map { case (_, i) =>
i.toLong
}
}
(inpIdsLong, posIdsLong)
} else {
// Subsequent passes
val inpIdsLong = decoderInputIds.map { tokenIds => tokenIds.last.toLong }
val posIdsLong = decoderInputIds.map { tokenIds =>
tokenIds.zipWithIndex.map { case (_, i) =>
i.toLong
}.last
}
(inpIdsLong, posIdsLong)
}
val attentionMask: Array[Long] =
inputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) }
decoderInputIds.flatMap { tokenIds => tokenIds.map(_ => 1L) }

val batchSize: Int = inputIds.length
val batchSize: Int = decoderInputIds.length
val beamIdx: Array[Int] = new Array[Int](batchSize)
val shape: Array[Int] = Array(batchSize, inputIdsLong.length / batchSize)

val inputIdsLongTensor: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(shape, inputIdsLong)
val decoderAttentionMask: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(Array(batchSize, inputIds.head.length), attentionMask)
new org.intel.openvino.Tensor(Array(batchSize, decoderInputIds.head.length), attentionMask)
val decoderPositionIDs: org.intel.openvino.Tensor =
new org.intel.openvino.Tensor(shape, inputPositionIDsLong)
val beamIdxTensor: org.intel.openvino.Tensor =
Expand All @@ -318,7 +327,6 @@ private[johnsnowlabs] class LLAMA2(

val result = inferRequest.get_tensor("logits")
val logitsRaw = result.data()
nextPositionId = Some(inputIds.map(tokenIds => tokenIds.length.toLong))

val sequenceLength = inputIdsLong.length / batchSize
val decoderOutputs = (0 until batchSize).map(i => {
Expand Down

0 comments on commit adc193e

Please sign in to comment.