diff --git a/python/sparknlp/annotator/seq2seq/llama2_transformer.py b/python/sparknlp/annotator/seq2seq/llama2_transformer.py index 0960d53c09cc11..c5c80fbf00692e 100644 --- a/python/sparknlp/annotator/seq2seq/llama2_transformer.py +++ b/python/sparknlp/annotator/seq2seq/llama2_transformer.py @@ -291,9 +291,9 @@ def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Tran minOutputLength=0, maxOutputLength=20, doSample=False, - temperature=1.0, + temperature=0.6, topK=50, - topP=1.0, + topP=0.9, repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala index 2a87dc4352b86a..8fdef329ad1f53 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/LLAMA2TestSpec.scala @@ -25,7 +25,7 @@ import org.scalatest.flatspec.AnyFlatSpec class LLAMA2TestSpec extends AnyFlatSpec { - "llama-7b" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs FastTest in { + "llama-7b" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs SlowTest in { // Even tough the Paper states temperature in interval [0,1), using temperature=0 will result in division by 0 error. // Also DoSample=True may result in infinities being generated and distFiltered.length==0 which results in exception if we don't return 0 instead internally. val testData = ResourceHelper.spark