Skip to content

Commit

Permalink
Refactor OpenAIEmbeddings (#14334)
Browse files Browse the repository at this point in the history
* SPARKNLP-1036: Onnx Example notebooks (#14234)

* SPARKNLP-1036: Fix dev python kernel names

* SPARKNLP-1036: Bump transformers version

* SPARKNLP-1036: Fix Colab buttons

* SPARKNLP-1036: Pin onnx version for compatibility

* SPARKNLP-1036: Upgrade Spark version

* SPARKNLP-1036: Minor Fixes

* SPARKNLP-1036: Clean Metadata

* SPARKNLP-1036: Add/Adjust Documentation

- Note for supported Spark Version of Annotators
- added missing Documentation for BGEEmbeddings

* Fixies (#14307)

* refactor OpenAIEmbeddings in Scala

* refactor OpenAIEmbeddings in Python

* add pytest.mark.slow and improve doc

---------

Co-authored-by: Devin Ha <33089471+DevinTDHa@users.noreply.github.com>
Co-authored-by: Lev <agsfer@gmail.com>
Co-authored-by: Maziyar Panahi <maziyar.panahi@iscpif.fr>
  • Loading branch information
4 people authored Jun 28, 2024
1 parent e88682c commit 09dc500
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 116 deletions.
Empty file.
10 changes: 10 additions & 0 deletions python/com/johnsnowlabs/ml/ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import sys

# Spark NLP dropped Python 2 support in 4.0; fail fast with a clear message
# rather than letting Python-2 users hit obscure syntax/import errors later.
if sys.version_info[0] == 2:
    raise ImportError(
        "Spark NLP for Python 2.x is deprecated since version >= 4.0. "
        "Please use an older version to use it with this Python version."
    )
else:
    import sparknlp

    # Alias the JVM-style package path to the Python ``sparknlp`` module so
    # that references to ``com.johnsnowlabs.ml.ai`` (e.g. from deserialized
    # pipelines) resolve to the same Python implementation.
    sys.modules['com.johnsnowlabs.ml.ai'] = sparknlp
1 change: 1 addition & 0 deletions python/sparknlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
sys.modules['com.johnsnowlabs.nlp.annotators.coref'] = annotator
sys.modules['com.johnsnowlabs.nlp.annotators.cv'] = annotator
sys.modules['com.johnsnowlabs.nlp.annotators.audio'] = annotator
sys.modules['com.johnsnowlabs.ml.ai'] = annotator

annotators = annotator
embeddings = annotator
Expand Down
112 changes: 43 additions & 69 deletions python/sparknlp/annotator/openai/openai_embeddings.py

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions python/test/annotator/embeddings/open_ai_embeddings_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2017-2022 John Snow Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
from sparknlp.annotator import *
from sparknlp.base import *
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession

@pytest.mark.slow
class OpenAIEmbeddingsTestCase(unittest.TestCase):
    """Integration test for :class:`OpenAIEmbeddings`.

    Requires a valid OpenAI API key set via the
    ``spark.jsl.settings.openai.api.key`` Spark config below; without one the
    annotator's remote calls will fail, hence the ``slow`` marker.
    """

    def setUp(self):
        # Local Spark session; the OpenAI key is injected through Spark conf.
        self.spark = SparkSession.builder \
            .appName("Tests") \
            .master("local[*]") \
            .config("spark.driver.memory", "8G") \
            .config("spark.driver.maxResultSize", "2G") \
            .config("spark.jars", "lib/sparknlp.jar") \
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
            .config("spark.kryoserializer.buffer.max", "1000m") \
            .config("spark.jsl.settings.openai.api.key", "") \
            .getOrCreate()

    def tearDown(self):
        # Stop the session created in setUp; the original version never
        # stopped it, leaking the driver/JVM resources across test runs.
        self.spark.stop()

    def test_openai_embeddings(self):
        documentAssembler = DocumentAssembler() \
            .setInputCol("text") \
            .setOutputCol("document")
        openai_embeddings = OpenAIEmbeddings() \
            .setInputCols("document") \
            .setOutputCol("embeddings") \
            .setModel("text-embedding-ada-002")

        # Round-trip the annotator through save/load so the test also covers
        # (de)serialization of its parameters.
        import tempfile
        save_path = "file:///" + tempfile.gettempdir() + "/openai_embeddings"
        openai_embeddings.write().overwrite().save(save_path)
        loaded = OpenAIEmbeddings.load(save_path)

        pipeline = Pipeline().setStages([
            documentAssembler,
            loaded
        ])

        sample_text = [["The food was delicious and the waiter..."]]
        sample_df = self.spark.createDataFrame(sample_text).toDF("text")
        pipeline.fit(sample_df).transform(sample_df).select("embeddings").show(truncate=False)



# Allow running this test module directly; pytest is the usual entry point.
if __name__ == '__main__':
    unittest.main()
99 changes: 54 additions & 45 deletions src/main/scala/com/johnsnowlabs/ml/ai/OpenAIEmbeddings.scala

Large diffs are not rendered by default.

41 changes: 39 additions & 2 deletions src/test/scala/com/johnsnowlabs/ml/ai/OpenAIEmbeddingsTest.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.nlp.annotators.SparkSessionTest
import com.johnsnowlabs.tags.SlowTest
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.scalatest.flatspec.AnyFlatSpec

class OpenAIEmbeddingsTest extends AnyFlatSpec with SparkSessionTest {
class OpenAIEmbeddingsTest extends AnyFlatSpec {

  // Standalone local session (this suite no longer mixes in SparkSessionTest)
  // so the OpenAI API key can be injected via spark.jsl.settings.openai.api.key.
  private val spark = SparkSession
    .builder()
    .appName("test")
    .master("local[*]")
    .config("spark.driver.memory", "16G")
    .config("spark.driver.maxResultSize", "0")
    .config("spark.kryoserializer.buffer.max", "2000M")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.jsl.settings.openai.api.key",
      "" // Set your OpenAI API key here...
    )
    .getOrCreate()

  import spark.implicits._
  // Shared first pipeline stage: turns the raw "text" column into "document"
  // annotations consumed by OpenAIEmbeddings in the tests below.
  private val documentAssembler =
    new com.johnsnowlabs.nlp.DocumentAssembler().setInputCol("text").setOutputCol("document")

"OpenAIEmbeddings" should "generate a completion for prompts" taggedAs SlowTest in {
// Set OPENAI_API_KEY env variable to make this work
Expand All @@ -25,4 +40,26 @@ class OpenAIEmbeddingsTest extends AnyFlatSpec with SparkSessionTest {
completionDF.select("embeddings").show(false)
}

"OpenAIEmbeddings" should "work with escape chars" taggedAs SlowTest in {
val data = Seq(
(1, "Hello \"World\""),
(2, "Hello \n World"),
(3, "Hello \t World"),
(4, "Hello \r World"),
(5, "Hello \b World"),
(6, "Hello \f World"),
(7, "Hello \\ World"))
val columns = Seq("id", "text")
val testDF = spark.createDataFrame(data).toDF(columns: _*)

val openAIEmbeddings = new OpenAIEmbeddings()
.setInputCols("document")
.setOutputCol("embeddings")
.setModel("text-embedding-ada-002")

val pipeline = new Pipeline().setStages(Array(documentAssembler, openAIEmbeddings))
val resultDF = pipeline.fit(testDF).transform(testDF)
resultDF.select("embeddings").show(false)
}

}

0 comments on commit 09dc500

Please sign in to comment.