diff --git a/examples/python/transformers/HuggingFace_in_Spark_NLP_BertForZeroShotClassification.ipynb b/examples/python/transformers/HuggingFace_in_Spark_NLP_BertForZeroShotClassification.ipynb new file mode 100644 index 00000000000000..796edc7d0ac0e1 --- /dev/null +++ b/examples/python/transformers/HuggingFace_in_Spark_NLP_BertForZeroShotClassification.ipynb @@ -0,0 +1,630 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "8IXf_Q668WRo" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20BertForZeroShotClassification.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fDfihUkE8WRr" + }, + "source": [ + "## Import BertForZeroShotClassification models from HuggingFace 🤗 into Spark NLP 🚀 \n", + "\n", + "Let's keep in mind a few things before we start 😊 \n", + "\n", + "- This feature is only in `Spark NLP 4.4.0` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import Bert models trained/fine-tuned for sequence classification via `BertForSequenceClassification` or `TFBertForSequenceClassification`. We can use these models for zero-shot classification.\n", + " - These models are usually under `Sequence Classification` category and have `bert` in their labels\n", + " - For zero-shot classification, we will use models trained on the nli data sets. The model should have been trained on the labels `contradiction`, `entailment` and `neutral`.\n", + "- Reference: [TFBertForSequenceClassification](https://huggingface.co/docs/transformers/main/en/model_doc/bert#transformers.TFBertForSequenceClassification)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vMg3NbLo8WRs" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ykej1XKH8WRu" + }, + "source": [ + "- Let's install `HuggingFace` and `TensorFlow`. You don't need `TensorFlow` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock TensorFlow on `2.11.0` version and Transformers on `4.25.1`. This doesn't mean it won't work with the future releases, but we wanted you to know which versions have been tested successfully." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yn28bSQi8WRu", + "outputId": "b49cc806-96c5-4013-d17b-cade1e93960a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.8/5.8 MB\u001b[0m \u001b[31m63.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m588.3/588.3 MB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m76.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m59.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m65.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m82.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.2/439.2 kB\u001b[0m \u001b[31m37.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m107.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow-datasets 4.9.2 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.19.6 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q transformers==4.25.1 tensorflow==2.11.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehfCmKt98WRw" + }, + "source": [ + "- HuggingFace comes with a native `saved_model` feature inside `save_pretrained` function for TensorFlow based models. We will use that to save it as TF `SavedModel`.\n", + "- We'll use [bert-base-mnli](https://huggingface.co/aloxatel/bert-base-mnli) model from HuggingFace as an example\n", + "- In addition to `TFBertForSequenceClassification` we also need to save the `BertTokenizer`. This is the same for every model, these are assets needed for tokenization inside Spark NLP." + ] + }, + { + "cell_type": "code", + "source": [ + "from transformers import TFBertForSequenceClassification, BertTokenizer \n", + "import tensorflow as tf\n", + "\n", + "MODEL_NAME = 'aloxatel/bert-base-mnli'\n", + "\n", + "tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)\n", + "tokenizer.save_pretrained('./{}_tokenizer/'.format(MODEL_NAME))\n", + "\n", + "try:\n", + " model = TFBertForSequenceClassification.from_pretrained(MODEL_NAME)\n", + "except:\n", + " model = TFBertForSequenceClassification.from_pretrained(MODEL_NAME, from_pt=True)\n", + " \n", + "# Define TF Signature\n", + "@tf.function(\n", + " input_signature=[\n", + " {\n", + " \"input_ids\": tf.TensorSpec((None, None), tf.int32, name=\"input_ids\"),\n", + " \"attention_mask\": tf.TensorSpec((None, None), tf.int32, name=\"attention_mask\"),\n", + " \"token_type_ids\": tf.TensorSpec((None, None), tf.int32, name=\"token_type_ids\"),\n", + " }\n", + " ]\n", + ")\n", + "def serving_fn(input):\n", + " return model(input)\n", + "\n", + "model.save_pretrained(\"./{}\".format(MODEL_NAME), saved_model=True, signatures={\"serving_default\": serving_fn})" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LsiRkfEBQTzS", + "outputId": "f80aa406-d04c-4541-ba08-37cd63ad5065" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "All PyTorch model weights were used when initializing TFBertForSequenceClassification.\n", + "\n", + "All the weights of TFBertForSequenceClassification were initialized from the PyTorch model.\n", + "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.\n", + "WARNING:absl:Found untraced functions such as embeddings_layer_call_fn, embeddings_layer_call_and_return_conditional_losses, encoder_layer_call_fn, encoder_layer_call_and_return_conditional_losses, pooler_layer_call_fn while saving (showing 5 of 420). These functions will not be directly callable after loading.\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eDjo0QGq8WRy" + }, + "source": [ + "Let's have a look inside these two directories and see what we are dealing with:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "daGPGUdz8WRz", + "outputId": "11d8c9bc-ac26-42d6-d3e0-fc08ba159102" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 427968\n", + "-rw-r--r-- 1 root root 813 Jun 6 15:13 config.json\n", + "drwxr-xr-x 3 root root 4096 Jun 6 15:13 saved_model\n", + "-rw-r--r-- 1 root root 438226204 Jun 6 15:13 tf_model.h5\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CwQH0R7h8WR1", + "outputId": "39dd8684-d1a7-4d51-daf8-d8bb994f1d01" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 9208\n", + "drwxr-xr-x 2 root root 4096 Jun 6 15:13 assets\n", + "-rw-r--r-- 1 root root 56 Jun 6 15:13 fingerprint.pb\n", + "-rw-r--r-- 1 root root 166830 Jun 6 15:13 keras_metadata.pb\n", + "-rw-r--r-- 1 root root 9245668 Jun 6 15:13 saved_model.pb\n", + "drwxr-xr-x 2 root root 4096 Jun 6 15:13 variables\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}/saved_model/1" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IPztfyM38WR2", + "outputId": "67c260e5-dff1-418e-85cd-229876e429f0" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 236\n", + "-rw-r--r-- 1 root root 125 Jun 6 15:12 special_tokens_map.json\n", + "-rw-r--r-- 1 root root 540 Jun 6 15:12 tokenizer_config.json\n", + "-rw-r--r-- 1 root root 231508 Jun 6 15:12 vocab.txt\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}_tokenizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gjrYDipS8WR2" + }, + "source": [ + "- As you can see, we need the SavedModel from `saved_model/1/` path\n", + "- We also be needing `vocab.txt` from the tokenizer\n", + "- All we need is to just copy the `vocab.txt` to `saved_model/1/assets` which Spark NLP will look for\n", + "- In addition to vocabs, we also need `labels` and their `ids` which is saved inside the model's config. We will save this inside `labels.txt`" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "QnQ0jke38WR3" + }, + "outputs": [], + "source": [ + "asset_path = '{}/saved_model/1/assets'.format(MODEL_NAME)\n", + "\n", + "!cp {MODEL_NAME}_tokenizer/vocab.txt {asset_path}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "WPvOXbeZ8WR4", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "ba3ac9d9-bcbe-4ca1-ff23-f163c667fea8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "['contradiction', 'entailment', 'neutral']\n" + ] + } + ], + "source": [ + "# get label strings\n", + "labels = [model.config.id2label[l] for l, v in model.config.id2label.items()]\n", + "print(labels)\n", + "\n", + "with open(asset_path+'/labels.txt', 'w') as f:\n", + " f.write('\\n'.join(labels))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UzQ650AZ8WR4" + }, + "source": [ + "Voila! We have our `vocab.txt` and `labels.txt` inside assets directory" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QcBOfJ918WR4", + "outputId": "0b3dbe3b-3b43-4f58-f5f8-d5a4151ebcbd" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 232\n", + "-rw-r--r-- 1 root root 32 Jun 6 15:14 labels.txt\n", + "-rw-r--r-- 1 root root 231508 Jun 6 15:14 vocab.txt\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}/saved_model/1/assets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zk28iNof8WR5" + }, + "source": [ + "## Import and Save BertForZeroShotClassification in Spark NLP\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J__aVVu48WR5" + }, + "source": [ + "- Let's install and setup Spark NLP in Google Colab\n", + "- This part is pretty easy via our simple script" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "udnbTHNj8WR6", + "outputId": "5c00752b-c7a0-4bad-b369-5052af7ffcb5" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Installing PySpark 3.2.3 and Spark NLP 4.4.3\n", + "setup Colab for PySpark 3.2.3 and Spark NLP 4.4.3\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.5/281.5 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m489.8/489.8 kB\u001b[0m \u001b[31m42.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.7/199.7 kB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5u9B2ldj8WR6" + }, + "source": [ + "Let's start Spark with Spark NLP included via our simple `start()` function" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "twQ6BHyo8WR6" + }, + "outputs": [], + "source": [ + "import sparknlp\n", + "# let's start Spark with Spark NLP\n", + "spark = sparknlp.start()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rOEy0EXR8WR7" + }, + "source": [ + "- Let's use `loadSavedModel` functon in `BertForZeroShotClassification` which allows us to load TensorFlow model in SavedModel format\n", + "- Most params can be set later when you are loading this model in `BertForZeroShotClassification` in runtime like `setMaxSentenceLength`, so don't worry what you are setting them now\n", + "- `loadSavedModel` accepts two params, first is the path to the TF SavedModel. The second is the SparkSession that is `spark` variable we previously started via `sparknlp.start()`\n", + "- NOTE: `loadSavedModel` accepts local paths in addition to distributed file systems such as `HDFS`, `S3`, `DBFS`, etc. This feature was introduced in Spark NLP 4.2.2 release. Keep in mind the best and recommended way to move/share/reuse Spark NLP models is to use `write.save` so you can use `.load()` from any file systems natively." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "lcqReFJO8WR7" + }, + "outputs": [], + "source": [ + "from sparknlp.annotator import *\n", + "from sparknlp.base import *\n", + "\n", + "zero_shot_classifier = BertForZeroShotClassification.loadSavedModel(\n", + " '{}/saved_model/1'.format(MODEL_NAME),\n", + " spark\n", + " )\\\n", + " .setInputCols([\"document\", \"token\"]) \\\n", + " .setOutputCol(\"class\") \\\n", + " .setCandidateLabels([\"urgent\", \"mobile\", \"travel\", \"movie\", \"music\", \"sport\", \"weather\", \"technology\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VmHVmBCo8WR9" + }, + "source": [ + "- Let's save it on disk so it is easier to be moved around and also be used later via `.load` function" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "9RBvw6p58WR9" + }, + "outputs": [], + "source": [ + "zero_shot_classifier.write().overwrite().save(\"./{}_spark_nlp\".format(MODEL_NAME))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DgUg2p0v8WR9" + }, + "source": [ + "Let's clean up stuff we don't need anymore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cdBziZhw8WR-" + }, + "outputs": [], + "source": [ + "!rm -rf {MODEL_NAME}_tokenizer {MODEL_NAME}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_iwYIQ6U8WR-" + }, + "source": [ + "Awesome 😎 !\n", + "\n", + "This is your BertForZeroShotClassification model from HuggingFace 🤗 loaded and saved by Spark NLP 🚀 " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8JAkr3438WR-", + "outputId": "5a8535dd-b945-4b8f-f95e-b5fb23b8cb28" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 436628\n", + "-rw-r--r-- 1 root root 447094331 Jun 6 15:16 bert_classification_tensorflow\n", + "drwxr-xr-x 5 root root 4096 Jun 6 15:16 fields\n", + "drwxr-xr-x 2 root root 4096 Jun 6 15:16 metadata\n" + ] + } + ], + "source": [ + "! ls -l {MODEL_NAME}_spark_nlp" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D5c2xWtt8WR-" + }, + "source": [ + "Now let's see how we can use it on other machines, clusters, or any place you wish to use your new and shiny BertForSequenceClassification model 😊 " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "JjxWoPhW8WR_" + }, + "outputs": [], + "source": [ + "zero_shot_classifier_loaded = BertForZeroShotClassification.load(\"./{}_spark_nlp\".format(MODEL_NAME))\\\n", + " .setInputCols([\"document\",'token'])\\\n", + " .setOutputCol(\"class\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rAITDhUg8WSA" + }, + "source": [ + "This is how you can use your loaded classifier model in Spark NLP 🚀 pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b4svOlV88WSA", + "outputId": "839f4e33-3a27-4ebe-ea2b-64ecd27d628a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+------------+\n", + "| result|\n", + "+------------+\n", + "| [urgent]|\n", + "|[technology]|\n", + "| [mobile]|\n", + "| [travel]|\n", + "| [movie]|\n", + "| [sport]|\n", + "| [urgent]|\n", + "+------------+\n", + "\n" + ] + } + ], + "source": [ + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "from pyspark.ml import Pipeline, PipelineModel\n", + "\n", + "document_assembler = DocumentAssembler() \\\n", + " .setInputCol(\"text\") \\\n", + " .setOutputCol(\"document\")\n", + "\n", + "tokenizer = Tokenizer().setInputCols(\"document\").setOutputCol(\"token\")\n", + "\n", + "pipeline = Pipeline(stages=[\n", + " document_assembler,\n", + " tokenizer,\n", + " zero_shot_classifier_loaded\n", + "])\n", + "\n", + "text = [[\"I have a problem with my iphone that needs to be resolved asap!!\"],\n", + " [\"Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.\"],\n", + " [\"I have a phone and I love it!\"],\n", + " [\"I really want to visit Germany and I am planning to go there next year.\"],\n", + " [\"Let's watch some movies tonight! I am in the mood for a horror movie.\"],\n", + " [\"Have you watched the match yesterday? It was a great game!\"],\n", + " [\"We need to harry up and get to the airport. We are going to miss our flight!\"]]\n", + "\n", + "# create a DataFrame in PySpark\n", + "inputDataset = spark.createDataFrame(text, [\"text\"])\n", + "model = pipeline.fit(inputDataset)\n", + "model.transform(inputDataset).select(\"class.result\").show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "26gEdXR28WSB" + }, + "source": [ + "That's it! You can now go wild and use hundreds of `BertForSequenceClassification` models as zero-shot classifiers from HuggingFace 🤗 in Spark NLP 🚀 " + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python [conda env:nlpdev]", + "language": "python", + "name": "conda-env-nlpdev-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/python/transformers/HuggingFace_in_Spark_NLP_DistilBertForZeroClassification.ipynb b/examples/python/transformers/HuggingFace_in_Spark_NLP_DistilBertForZeroClassification.ipynb new file mode 100644 index 00000000000000..139799db44700a --- /dev/null +++ b/examples/python/transformers/HuggingFace_in_Spark_NLP_DistilBertForZeroClassification.ipynb @@ -0,0 +1,2479 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "8IXf_Q668WRo" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20DistilBertForZeroShotClassification.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fDfihUkE8WRr" + }, + "source": [ + "## Import DistilBertForZeroShotClassification models from HuggingFace 🤗 into Spark NLP 🚀 \n", + "\n", + "Let's keep in mind a few things before we start 😊 \n", + "\n", + "- This feature is only in `Spark NLP 4.4.1` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import DistilBERT models trained/fine-tuned for sequence classification via `DistilBertForSequenceClassification` or `TFDistilBertForSequenceClassification`. We can use these models for zero-shot classification.\n", + " - These models are usually under `Sequence Classification` category and have `distilbert` in their labels\n", + " - For zero-shot classification, We will usually use models trained on the nli data sets for best performance.\n", + "- Reference: [TFDistilBertForSequenceClassification](https://huggingface.co/transformers/model_doc/distilbert.html#tfdistilbertforsequenceclassification)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vMg3NbLo8WRs" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ykej1XKH8WRu" + }, + "source": [ + "- Let's install `HuggingFace` and `TensorFlow`. You don't need `TensorFlow` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock TensorFlow on `2.11.0` version and Transformers on `4.25.1`. This doesn't mean it won't work with the future releases, but we wanted you to know which versions have been tested successfully." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yn28bSQi8WRu", + "outputId": "54c3b582-f829-4052-ce29-791454c17e82" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.8/5.8 MB\u001b[0m \u001b[31m62.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m588.3/588.3 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m19.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m72.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m62.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m109.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.2/439.2 kB\u001b[0m \u001b[31m38.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m103.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow-datasets 4.9.2 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.19.6 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q transformers==4.25.1 tensorflow==2.11.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehfCmKt98WRw" + }, + "source": [ + "- HuggingFace comes with a native `saved_model` feature inside `save_pretrained` function for TensorFlow based models. We will use that to save it as TF `SavedModel`.\n", + "- We'll use [distilbert-base-uncased-mnli](https://huggingface.co/typeform/distilbert-base-uncased-mnli) model from HuggingFace as an example\n", + " - For zero-shot classification, We will usually use models trained on the (m)nli data set for best performance.\n", + "- In addition to `TFDistilBertForSequenceClassification` we also need to save the `DistilBertTokenizer`. This is the same for every model, these are assets needed for tokenization inside Spark NLP." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 475, + "referenced_widgets": [ + "f8a1ac89cf5e4a26ad6d30c03a2b8e4d", + "a2a5e6ebdac742f7b1d4d33123c1c744", + "b10a1e979e3344ee857b1dfdf88ca748", + "851a51a9db664e2ca38b6d195384f47a", + "83fbc4e3720b4791aedc53b77ed3cb19", + "9d2fa61ddebc49d88c0e9e44d83cf36e", + "66238454a296491e8bd5ef008d450e38", + "695c064280044e169f15d9347ab23281", + "726d4e34581145479154f456c696c278", + "6df0bcc1a5f144b1bc7fe6f3cfa6a05f", + "0706c29c1d3c41c2a964f61b2fb72e20", + "9535415ff7c24a38b8d2e50e3774af01", + "660ab89806b94f72986d05f86785a4ad", + "13c72780c3bc4b3987fb5f78bb2a4904", + "f9436cd1fbbe4d75917b1a4171fed17a", + "ace0263340a34ac9b8579d5b4666faa7", + "b23d9d2b72d0488caf2b5dcca31b6314", + "ce75b803b79a43969269fdf6a890e16e", + "c9926babdab441b8a0f2a4c4f59ef92b", + "5399e1fe768345a490175e21a1ece578", + "3abdf0fcd9e54331b70c8a5008f0c1b5", + "58dc8096add9432180fd5c8b6eb101c4", + "c73aaf2bcb674518bb751f508f09292a", + "09818de216094490870001aca113adba", + "856bda93ec8941be917ab99623ca4851", + "ed5990ff8df04548bd259ae01b6e134a", + "8de2583fc33d48fd8e71b70c9919ea24", + "522d095cdad54bd48a22260924ecb9c2", + "f872b7001bcc436985d5cc2e72f7482c", + "cd2d0543ae6242d2a5d9571d5d4d0726", + "262a30c072604c048c828db2cd210176", + "40e34c26038146de9224f61d7bbda5b4", + "66698d2afeab47458a94bb4aa089184f", + "16d73d0610d04e52a0a93a45fe09854c", + "9a1344b59691413480a60f4e3d8cb741", + "d7af4f84b0ee43f7a26bb4ea7d72f048", + "7eb50fb21bf74d0f928741211108ea94", + "b7599ef85f9943e8874a30d5ca071825", + "d0993edddf534814be5e9e9c726e6011", + "01c0067c52b44de59fe07165babc8594", + "abd7567a0bd0495c9a721843786a276d", + "3d71b8d6a7bd4e26a0e8f6577dbc6d93", + "c954347f47894d1daa8fb2802e6ebb92", + "38ee16f0a5d84d92a0be6b3bef817dfe", + "68a558a48ab144a28137110ffc0dcf3b", + "341116f0e5074e3f8ce48b1cb0709b00", + "cc903d6ee439487ab10b881d10167ccf", + "9a094b07268445b0a7de94bb21e92707", + "afaa696d4c2c40df8fcc27ce79f634f1", + "4db458113c9b414d8044cfb8adb7bbba", + "38b3aa539af04122a723bb79bee08485", + "57c1248bde6549f4aa5be8aa9b5a353f", + "830a592bd3594d2cae4fe35c0620b94e", + "4b5a8de4b92d412eb59306984ac327b1", + "5cf0e182cba142e7a7e14bac01e39e04" + ] + }, + "id": "oCOSyDn88WRx", + "outputId": "a2b5b435-eb43-4f62-93ad-1c9275c3a21e" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/232k [00:00, because it is not built.\n", + "WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.\n", + "WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.\n", + "WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.\n", + "WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.\n", + "WARNING:tensorflow:Skipping full serialization of Keras layer , because it is not built.\n", + "WARNING:absl:Found untraced functions such as embeddings_layer_call_fn, embeddings_layer_call_and_return_conditional_losses, transformer_layer_call_fn, transformer_layer_call_and_return_conditional_losses, LayerNorm_layer_call_fn while saving (showing 5 of 164). These functions will not be directly callable after loading.\n" + ] + } + ], + "source": [ + "from transformers import TFDistilBertForSequenceClassification, DistilBertTokenizer \n", + "import tensorflow as tf\n", + "\n", + "MODEL_NAME = 'typeform/distilbert-base-uncased-mnli'\n", + "\n", + "tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)\n", + "tokenizer.save_pretrained('./{}_tokenizer/'.format(MODEL_NAME))\n", + "\n", + "try:\n", + " model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME)\n", + "except:\n", + " model = TFDistilBertForSequenceClassification.from_pretrained(MODEL_NAME, from_pt=True)\n", + " \n", + "# Define TF Signature\n", + "@tf.function(\n", + " input_signature=[\n", + " {\n", + " \"input_ids\": tf.TensorSpec((None, None), tf.int32, name=\"input_ids\"),\n", + " \"attention_mask\": tf.TensorSpec((None, None), tf.int32, name=\"attention_mask\") \n", + " }\n", + " ]\n", + ")\n", + "def serving_fn(input):\n", + " return model(input)\n", + "\n", + "model.save_pretrained(\"./{}\".format(MODEL_NAME), saved_model=True, signatures={\"serving_default\": serving_fn})\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eDjo0QGq8WRy" + }, + "source": [ + "Let's have a look inside these two directories and see what we are dealing with:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "daGPGUdz8WRz", + "outputId": "d84e4167-28a1-47f0-f7e5-d28722ffe63a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 261688\n", + "-rw-r--r-- 1 root root 753 Jun 3 15:53 config.json\n", + "drwxr-xr-x 3 root root 4096 Jun 3 15:53 saved_model\n", + "-rw-r--r-- 1 root root 267954880 Jun 3 15:53 tf_model.h5\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CwQH0R7h8WR1", + "outputId": "8abc85a8-3f94-4b61-9ed4-f52dcf969092" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 5008\n", + "drwxr-xr-x 2 root root 4096 Jun 3 15:53 assets\n", + "-rw-r--r-- 1 root root 56 Jun 3 15:53 fingerprint.pb\n", + "-rw-r--r-- 1 root root 80289 Jun 3 15:53 keras_metadata.pb\n", + "-rw-r--r-- 1 root root 5032374 Jun 3 15:53 saved_model.pb\n", + "drwxr-xr-x 2 root root 4096 Jun 3 15:53 variables\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}/saved_model/1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IPztfyM38WR2", + "outputId": "11fad132-29d2-4057-da33-da8b87bcb38b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 236\n", + "-rw-r--r-- 1 root root 125 Jun 3 15:52 special_tokens_map.json\n", + "-rw-r--r-- 1 root root 574 Jun 3 15:52 tokenizer_config.json\n", + "-rw-r--r-- 1 root root 231508 Jun 3 15:52 vocab.txt\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}_tokenizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gjrYDipS8WR2" + }, + "source": [ + "- As you can see, we need the SavedModel from `saved_model/1/` path\n", + "- We also be needing `vocab.txt` from the tokenizer\n", + "- All we need is to just copy the `vocab.txt` to `saved_model/1/assets` which Spark NLP will look for\n", + "- In addition to vocabs, we also need `labels` and their `ids` which is saved inside the model's config. We will save this inside `labels.txt`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "QnQ0jke38WR3" + }, + "outputs": [], + "source": [ + "asset_path = '{}/saved_model/1/assets'.format(MODEL_NAME)\n", + "\n", + "!cp {MODEL_NAME}_tokenizer/vocab.txt {asset_path}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WPvOXbeZ8WR4" + }, + "outputs": [], + "source": [ + "# get label2id dictionary \n", + "labels = model.config.label2id\n", + "# sort the dictionary based on the id\n", + "labels = sorted(labels, key=labels.get)\n", + "\n", + "with open(asset_path+'/labels.txt', 'w') as f:\n", + " f.write('\\n'.join(labels))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UzQ650AZ8WR4" + }, + "source": [ + "Voila! We have our `vocab.txt` and `labels.txt` inside assets directory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "QcBOfJ918WR4", + "outputId": "10112997-f328-4747-fb7f-0d37072e29e8" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 232\n", + "-rw-r--r-- 1 root root 32 Jun 3 15:53 labels.txt\n", + "-rw-r--r-- 1 root root 231508 Jun 3 15:53 vocab.txt\n" + ] + } + ], + "source": [ + "!ls -l {MODEL_NAME}/saved_model/1/assets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zk28iNof8WR5" + }, + "source": [ + "## Import and Save DistilBertForZeroShotClassification in Spark NLP\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J__aVVu48WR5" + }, + "source": [ + "- Let's install and setup Spark NLP in Google Colab\n", + "- This part is pretty easy via our simple script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "udnbTHNj8WR6", + "outputId": "e0b6b426-2e0d-4be8-d831-bcea5d64f288" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Installing PySpark 3.2.3 and Spark NLP 4.4.3\n", + "setup Colab for PySpark 3.2.3 and Spark NLP 4.4.3\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.5/281.5 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m489.8/489.8 kB\u001b[0m \u001b[31m34.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.7/199.7 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for pyspark (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "! wget -q http://setup.johnsnowlabs.com/colab.sh -O - | bash" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5u9B2ldj8WR6" + }, + "source": [ + "Let's start Spark with Spark NLP included via our simple `start()` function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "twQ6BHyo8WR6" + }, + "outputs": [], + "source": [ + "import sparknlp\n", + "# let's start Spark with Spark NLP\n", + "spark = sparknlp.start()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rOEy0EXR8WR7" + }, + "source": [ + "- Let's use `loadSavedModel` functon in `DistilBertForZeroShotClassification` which allows us to load TensorFlow model in SavedModel format\n", + "- Most params can be set later when you are loading this model in `DistilBertForZeroShotClassification` in runtime like `setMaxSentenceLength`, so don't worry what you are setting them now\n", + "- `loadSavedModel` accepts two params, first is the path to the TF SavedModel. The second is the SparkSession that is `spark` variable we previously started via `sparknlp.start()`\n", + "- NOTE: `loadSavedModel` accepts local paths in addition to distributed file systems such as `HDFS`, `S3`, `DBFS`, etc. This feature was introduced in Spark NLP 4.2.2 release. Keep in mind the best and recommended way to move/share/reuse Spark NLP models is to use `write.save` so you can use `.load()` from any file systems natively." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lcqReFJO8WR7" + }, + "outputs": [], + "source": [ + "from sparknlp.annotator import *\n", + "from sparknlp.base import *\n", + "\n", + "zero_shot_classifier = DistilBertForZeroShotClassification.loadSavedModel(\n", + " '{}/saved_model/1'.format(MODEL_NAME),\n", + " spark\n", + " )\\\n", + " .setInputCols([\"document\", \"token\"]) \\\n", + " .setOutputCol(\"class\") \\\n", + " .setCandidateLabels([\"urgent\", \"mobile\", \"travel\", \"movie\", \"music\", \"sport\", \"weather\", \"technology\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VmHVmBCo8WR9" + }, + "source": [ + "- Let's save it on disk so it is easier to be moved around and also be used later via `.load` function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9RBvw6p58WR9" + }, + "outputs": [], + "source": [ + "zero_shot_classifier.write().overwrite().save(\"./{}_spark_nlp\".format(MODEL_NAME))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DgUg2p0v8WR9" + }, + "source": [ + "Let's clean up stuff we don't need anymore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cdBziZhw8WR-" + }, + "outputs": [], + "source": [ + "!rm -rf {MODEL_NAME}_tokenizer {MODEL_NAME}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_iwYIQ6U8WR-" + }, + "source": [ + "Awesome 😎 !\n", + "\n", + "This is your DistilBertForSequenceClassification model from HuggingFace 🤗 loaded and saved by Spark NLP 🚀 " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8JAkr3438WR-", + "outputId": "2ec0cc08-2122-4301-e7dc-84fb91eabf5e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "total 266440\n", + "-rw-r--r-- 1 root root 272826157 Jun 3 15:58 distilbert_classification_tensorflow\n", + "drwxr-xr-x 5 root root 4096 Jun 3 15:58 fields\n", + "drwxr-xr-x 2 root root 4096 Jun 3 15:58 metadata\n" + ] + } + ], + "source": [ + "! ls -l {MODEL_NAME}_spark_nlp" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "D5c2xWtt8WR-" + }, + "source": [ + "Now let's see how we can use it on other machines, clusters, or any place you wish to use your new and shiny BertForSequenceClassification model 😊 " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "JjxWoPhW8WR_" + }, + "outputs": [], + "source": [ + "zero_shot_classifier_loaded = DistilBertForZeroShotClassification.load(\"./{}_spark_nlp\".format(MODEL_NAME))\\\n", + " .setInputCols([\"document\",'token'])\\\n", + " .setOutputCol(\"class\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rAITDhUg8WSA" + }, + "source": [ + "This is how you can use your loaded classifier model in Spark NLP 🚀 pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b4svOlV88WSA", + "outputId": "da5aefa6-efb2-43f4-9cf4-537cac5afe3b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "+------------+\n", + "| result|\n", + "+------------+\n", + "| [mobile]|\n", + "|[technology]|\n", + "| [mobile]|\n", + "| [travel]|\n", + "| [weather]|\n", + "| [sport]|\n", + "| [urgent]|\n", + "+------------+\n", + "\n" + ] + } + ], + "source": [ + "from sparknlp.base import *\n", + "from sparknlp.annotator import *\n", + "from pyspark.ml import Pipeline, PipelineModel\n", + "\n", + "document_assembler = DocumentAssembler() \\\n", + " .setInputCol(\"text\") \\\n", + " .setOutputCol(\"document\")\n", + "\n", + "tokenizer = Tokenizer().setInputCols(\"document\").setOutputCol(\"token\")\n", + "\n", + "pipeline = Pipeline(stages=[\n", + " document_assembler,\n", + " tokenizer,\n", + " zero_shot_classifier_loaded\n", + "])\n", + "\n", + "text = [[\"I have a problem with my iphone that needs to be resolved asap!!\"],\n", + " [\"Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.\"],\n", + " [\"I have a phone and I love it!\"],\n", + " [\"I really want to visit Germany and I am planning to go there next year.\"],\n", + " [\"Let's watch some movies tonight! I am in the mood for a horror movie.\"],\n", + " [\"Have you watched the match yesterday? It was a great game!\"],\n", + " [\"We need to harry up and get to the airport. We are going to miss our flight!\"]]\n", + "\n", + "# create a DataFrame in PySpark\n", + "inputDataset = spark.createDataFrame(text, [\"text\"])\n", + "model = pipeline.fit(inputDataset)\n", + "model.transform(inputDataset).select(\"class.result\").show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "26gEdXR28WSB" + }, + "source": [ + "That's it! You can now go wild and use hundreds of \n", + "`DistilBertForSequenceClassification` models as zero-shot classifiers from HuggingFace 🤗 in Spark NLP 🚀" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python [conda env:nlpdev]", + "language": "python", + "name": "conda-env-nlpdev-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "f8a1ac89cf5e4a26ad6d30c03a2b8e4d": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a2a5e6ebdac742f7b1d4d33123c1c744", + "IPY_MODEL_b10a1e979e3344ee857b1dfdf88ca748", + "IPY_MODEL_851a51a9db664e2ca38b6d195384f47a" + ], + "layout": "IPY_MODEL_83fbc4e3720b4791aedc53b77ed3cb19" + } + }, + "a2a5e6ebdac742f7b1d4d33123c1c744": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_9d2fa61ddebc49d88c0e9e44d83cf36e", + "placeholder": "​", + "style": "IPY_MODEL_66238454a296491e8bd5ef008d450e38", + "value": "Downloading (…)solve/main/vocab.txt: 100%" + } + }, + "b10a1e979e3344ee857b1dfdf88ca748": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_695c064280044e169f15d9347ab23281", + "max": 231508, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_726d4e34581145479154f456c696c278", + "value": 231508 + } + }, + "851a51a9db664e2ca38b6d195384f47a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6df0bcc1a5f144b1bc7fe6f3cfa6a05f", + "placeholder": "​", + "style": "IPY_MODEL_0706c29c1d3c41c2a964f61b2fb72e20", + "value": " 232k/232k [00:00<00:00, 1.43MB/s]" + } + }, + "83fbc4e3720b4791aedc53b77ed3cb19": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9d2fa61ddebc49d88c0e9e44d83cf36e": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "66238454a296491e8bd5ef008d450e38": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "695c064280044e169f15d9347ab23281": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "726d4e34581145479154f456c696c278": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "6df0bcc1a5f144b1bc7fe6f3cfa6a05f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0706c29c1d3c41c2a964f61b2fb72e20": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9535415ff7c24a38b8d2e50e3774af01": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_660ab89806b94f72986d05f86785a4ad", + "IPY_MODEL_13c72780c3bc4b3987fb5f78bb2a4904", + "IPY_MODEL_f9436cd1fbbe4d75917b1a4171fed17a" + ], + "layout": "IPY_MODEL_ace0263340a34ac9b8579d5b4666faa7" + } + }, + "660ab89806b94f72986d05f86785a4ad": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b23d9d2b72d0488caf2b5dcca31b6314", + "placeholder": "​", + "style": "IPY_MODEL_ce75b803b79a43969269fdf6a890e16e", + "value": "Downloading (…)cial_tokens_map.json: 100%" + } + }, + "13c72780c3bc4b3987fb5f78bb2a4904": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c9926babdab441b8a0f2a4c4f59ef92b", + "max": 112, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_5399e1fe768345a490175e21a1ece578", + "value": 112 + } + }, + "f9436cd1fbbe4d75917b1a4171fed17a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3abdf0fcd9e54331b70c8a5008f0c1b5", + "placeholder": "​", + "style": "IPY_MODEL_58dc8096add9432180fd5c8b6eb101c4", + "value": " 112/112 [00:00<00:00, 4.22kB/s]" + } + }, + "ace0263340a34ac9b8579d5b4666faa7": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "b23d9d2b72d0488caf2b5dcca31b6314": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ce75b803b79a43969269fdf6a890e16e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c9926babdab441b8a0f2a4c4f59ef92b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5399e1fe768345a490175e21a1ece578": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "3abdf0fcd9e54331b70c8a5008f0c1b5": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "58dc8096add9432180fd5c8b6eb101c4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "c73aaf2bcb674518bb751f508f09292a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_09818de216094490870001aca113adba", + "IPY_MODEL_856bda93ec8941be917ab99623ca4851", + "IPY_MODEL_ed5990ff8df04548bd259ae01b6e134a" + ], + "layout": "IPY_MODEL_8de2583fc33d48fd8e71b70c9919ea24" + } + }, + "09818de216094490870001aca113adba": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_522d095cdad54bd48a22260924ecb9c2", + "placeholder": "​", + "style": "IPY_MODEL_f872b7001bcc436985d5cc2e72f7482c", + "value": "Downloading (…)okenizer_config.json: 100%" + } + }, + "856bda93ec8941be917ab99623ca4851": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_cd2d0543ae6242d2a5d9571d5d4d0726", + "max": 258, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_262a30c072604c048c828db2cd210176", + "value": 258 + } + }, + "ed5990ff8df04548bd259ae01b6e134a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_40e34c26038146de9224f61d7bbda5b4", + "placeholder": "​", + "style": "IPY_MODEL_66698d2afeab47458a94bb4aa089184f", + "value": " 258/258 [00:00<00:00, 12.2kB/s]" + } + }, + "8de2583fc33d48fd8e71b70c9919ea24": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "522d095cdad54bd48a22260924ecb9c2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f872b7001bcc436985d5cc2e72f7482c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "cd2d0543ae6242d2a5d9571d5d4d0726": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "262a30c072604c048c828db2cd210176": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "40e34c26038146de9224f61d7bbda5b4": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "66698d2afeab47458a94bb4aa089184f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "16d73d0610d04e52a0a93a45fe09854c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_9a1344b59691413480a60f4e3d8cb741", + "IPY_MODEL_d7af4f84b0ee43f7a26bb4ea7d72f048", + "IPY_MODEL_7eb50fb21bf74d0f928741211108ea94" + ], + "layout": "IPY_MODEL_b7599ef85f9943e8874a30d5ca071825" + } + }, + "9a1344b59691413480a60f4e3d8cb741": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_d0993edddf534814be5e9e9c726e6011", + "placeholder": "​", + "style": "IPY_MODEL_01c0067c52b44de59fe07165babc8594", + "value": "Downloading (…)lve/main/config.json: 100%" + } + }, + "d7af4f84b0ee43f7a26bb4ea7d72f048": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_abd7567a0bd0495c9a721843786a276d", + "max": 776, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_3d71b8d6a7bd4e26a0e8f6577dbc6d93", + "value": 776 + } + }, + "7eb50fb21bf74d0f928741211108ea94": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_c954347f47894d1daa8fb2802e6ebb92", + "placeholder": "​", + "style": "IPY_MODEL_38ee16f0a5d84d92a0be6b3bef817dfe", + "value": " 776/776 [00:00<00:00, 27.6kB/s]" + } + }, + "b7599ef85f9943e8874a30d5ca071825": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d0993edddf534814be5e9e9c726e6011": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "01c0067c52b44de59fe07165babc8594": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "abd7567a0bd0495c9a721843786a276d": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "3d71b8d6a7bd4e26a0e8f6577dbc6d93": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c954347f47894d1daa8fb2802e6ebb92": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "38ee16f0a5d84d92a0be6b3bef817dfe": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "68a558a48ab144a28137110ffc0dcf3b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_341116f0e5074e3f8ce48b1cb0709b00", + "IPY_MODEL_cc903d6ee439487ab10b881d10167ccf", + "IPY_MODEL_9a094b07268445b0a7de94bb21e92707" + ], + "layout": "IPY_MODEL_afaa696d4c2c40df8fcc27ce79f634f1" + } + }, + "341116f0e5074e3f8ce48b1cb0709b00": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4db458113c9b414d8044cfb8adb7bbba", + "placeholder": "​", + "style": "IPY_MODEL_38b3aa539af04122a723bb79bee08485", + "value": "Downloading tf_model.h5: 100%" + } + }, + "cc903d6ee439487ab10b881d10167ccf": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_57c1248bde6549f4aa5be8aa9b5a353f", + "max": 267954880, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_830a592bd3594d2cae4fe35c0620b94e", + "value": 267954880 + } + }, + "9a094b07268445b0a7de94bb21e92707": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_4b5a8de4b92d412eb59306984ac327b1", + "placeholder": "​", + "style": "IPY_MODEL_5cf0e182cba142e7a7e14bac01e39e04", + "value": " 268M/268M [00:06<00:00, 40.4MB/s]" + } + }, + "afaa696d4c2c40df8fcc27ce79f634f1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4db458113c9b414d8044cfb8adb7bbba": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "38b3aa539af04122a723bb79bee08485": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "57c1248bde6549f4aa5be8aa9b5a353f": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "830a592bd3594d2cae4fe35c0620b94e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "4b5a8de4b92d412eb59306984ac327b1": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5cf0e182cba142e7a7e14bac01e39e04": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/python/transformers/HuggingFace_in_Spark_NLP_RoBertaForZeroShotClassification.ipynb b/examples/python/transformers/HuggingFace_in_Spark_NLP_RoBertaForZeroShotClassification.ipynb new file mode 100644 index 00000000000000..22ebb65c7945e0 --- /dev/null +++ b/examples/python/transformers/HuggingFace_in_Spark_NLP_RoBertaForZeroShotClassification.ipynb @@ -0,0 +1,2839 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "8IXf_Q668WRo" + }, + "source": [ + "![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20RoBertaForZeroShotClassification.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fDfihUkE8WRr" + }, + "source": [ + "## Import RoBertaForZeroShotClassification models from HuggingFace 🤗 into Spark NLP 🚀 \n", + "\n", + "Let's keep in mind a few things before we start 😊 \n", + "\n", + "- This feature is only in `Spark NLP 4.4.2` and after. So please make sure you have upgraded to the latest Spark NLP release\n", + "- You can import RoBerta models trained/fine-tuned for sequence classification via `RobertaForSequenceClassification` or `TFRobertaForSequenceClassification`. We can use these models for zero-shot classification.\n", + " - These models are usually under `Sequence Classification` category and have `roberta` in their labels\n", + " - For zero-shot classification, We will usually use models trained on the nli data sets for best performance.\n", + "- Reference: [TFRobertaForSequenceClassification](https://huggingface.co/docs/transformers/v4.29.1/en/model_doc/roberta#transformers.TFRobertaForSequenceClassification)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vMg3NbLo8WRs" + }, + "source": [ + "## Export and Save HuggingFace model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ykej1XKH8WRu" + }, + "source": [ + "- Let's install `HuggingFace` and `TensorFlow`. You don't need `TensorFlow` to be installed for Spark NLP, however, we need it to load and save models from HuggingFace.\n", + "- We lock TensorFlow on `2.11.0` version and Transformers on `4.25.1`. This doesn't mean it won't work with the future releases, but we wanted you to know which versions have been tested successfully." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yn28bSQi8WRu", + "outputId": "ca6816ea-232a-4d44-8526-d14948561b0a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.8/5.8 MB\u001b[0m \u001b[31m51.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m588.3/588.3 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m95.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m58.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m88.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m439.2/439.2 kB\u001b[0m \u001b[31m38.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m104.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow-datasets 4.9.2 requires protobuf>=3.20, but you have protobuf 3.19.6 which is incompatible.\n", + "tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.19.6 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install -q transformers==4.25.1 tensorflow==2.11.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ehfCmKt98WRw" + }, + "source": [ + "- HuggingFace comes with a native `saved_model` feature inside `save_pretrained` function for TensorFlow based models. We will use that to save it as TF `SavedModel`.\n", + "- We'll use [cross-encoder/nli-roberta-base](cross-encoder/nli-roberta-base) model from HuggingFace as an example\n", + " - For zero-shot classification, We will usually use models trained on the (m)nli data set for best performance.\n", + "- In addition to `TFRobertaForSequenceClassification` we also need to save the `RobertaTokenizer`. This is the same for every model, these are assets needed for tokenization inside Spark NLP." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 333, + "referenced_widgets": [ + "13304044c4d74d14af90338c5f3d5493", + "8bf57ab59ef7430298a33e893a6b6743", + "a7bc950c63434a4ea756b0689e69c0fb", + "fe6473dcaecc4bba8d4c5b1d24b56dde", + "1777d601c93842d7a8d6439671ffc4b6", + "a2a74759f3544822911bb9d208e19398", + "3f2f0b73ebb845c3be7bb024abcc168c", + "1f7796c027b247808ff57f0bc959c037", + "20c79f8d5e68453c961f8bc758b9cc05", + "b6302a2a129743f981b659778eb11605", + "eed2c7f332f247d9b794ae800c1e2604", + "7683b71b597340f9be1fd0df61d2e26e", + "7d61763ef4324825ac49027c0c503a41", + "317c2bd80fe34fc5beee20f741abe120", + "776446051c914eadb326d288dd25260e", + "bb9dcee68fc240a1863665efc8ac8798", + "4e83111b26c547329461050650b88ca5", + "ce5514402f6141ef8f958447036ff6d7", + "27db9b81b53141719d6916e8d90e43f1", + "f0432c9b017b46acab7bdb2f5839ea3f", + "95a6d8a45782463d90667e58acd39e8f", + "4b1ba2442b554cb9b127b86762211c5c", + "f46f7c60e06347bf8571060e0e33316d", + "8551ac2036b344b7a7c92ee4298aeb0c", + "e906ec85e3894643a6e3b16761600e9f", + "b1e94dce806f4a1bb6fdd780eec5d80e", + "05148ad940fc43b19ec9b4d7af0b339b", + "f397c9ef82fc410187f6e1efe3fc9775", + "9e557261a3c444a79e97cbedf0dd3d0b", + "85720396966f498db2975248ca100c6b", + "5b56199e5e7c447c98bb97b60c4241f5", + "a45f190ad5df4a6b9d058ea1162cfb03", + "aafe8988f20f47afb773dd31a8233ce7", + "9687bb7a41ef4e0d815f33b986d700e6", + "9268ee6bcfdc475c8af00fd363727531", + "66f2971ea05549768a5527a52f397305", + "6e4a2f12e3d34ee491d887d515842d98", + "8a7ac648178c40eb8caf767560e33eb0", + "b5e67cb55e624666ac086f4adc970959", + "2e8660d4bbaa46079abd448d8b628d18", + "df40d99ce76140019b8b47b2c75e752b", + "4a987727e9de47988c5ac0bf8b9e57dc", + "8d8a1992e90e4521bb6b9094540fd4e0", + "b4b8b11edb5043b89921d68d250603d5", + "d44690bc2a8c4466916e71a4c13cf0e1", + "5a9db27686e044c9bb299bcca53ec9e7", + "d7bef74ee03545968ea1cc559006d046", + "7d7007c1b14a40479479868563d26f9a", + "a85c3bdc21e2487b8a0e13bf212e2929", + "957865dabc1044249459439aae39c75d", + "e1ca1239d791483ba4b4300a9742fa36", + "5814640904e44dc8a28a3ff280c3ce9f", + "55a1233a0555432db42f32f37033c8e5", + "96c1b3c4c139440b87742a06ba7f16d8", + "d6f2944954564fed9819ae9cd63cb885", + "7475900bacdd41a185d080c62ad91a12", + "966f7d7fe6084c7d9c2a0c6668a106e5", + "091f5a1c52c24e9fb046e5da680b644f", + "ad95e9b538a24b2ca1f7163852d927bc", + "960d2088d5314f4986fa46a5338f02b7", + "fce6cd5893e64bf59fc4fd579a49a006", + "224f786d7cf1427a85150cde9f8eb09b", + "027552ca8e8c47d3ab477aebd8b53e03", + "c95275785ae441b186a887bbbbf38611", + "2a84bfe6579f47dea24d804b1671334c", + "4c08c4c39642458d96ed17d99943fcc3" + ] + }, + "id": "oCOSyDn88WRx", + "outputId": "3d224291-7e2a-4f8e-d5fd-3c045ec14bc3" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading (…)olve/main/vocab.json: 0%| | 0.00/899k [00:00