diff --git a/python/sparknlp/annotator/er/entity_ruler.py b/python/sparknlp/annotator/er/entity_ruler.py
index d470dec59ab2ac..daae01cfd74cd6 100755
--- a/python/sparknlp/annotator/er/entity_ruler.py
+++ b/python/sparknlp/annotator/er/entity_ruler.py
@@ -228,5 +228,5 @@ def pretrained(name, lang="en", remote_loc=None):
 
     @staticmethod
     def loadStorage(path, spark, storage_ref):
-        HasStorageModel.loadStorages(path, spark, storage_ref, EntityRulerModel.databases)
+        HasStorageModel.loadStorages(path, spark, storage_ref, EntityRulerModel.database)
 
diff --git a/python/sparknlp/base/light_pipeline.py b/python/sparknlp/base/light_pipeline.py
index d17c5e8fb2b695..0622652fc01a42 100644
--- a/python/sparknlp/base/light_pipeline.py
+++ b/python/sparknlp/base/light_pipeline.py
@@ -75,7 +75,7 @@ def _validateStagesInputCols(self, stages):
             input_cols = stage.getInputCols()
             if type(input_cols) == str:
                 input_cols = [input_cols]
-            input_annotator_types = stage.inputAnnotatorTypes
+            input_annotator_types = stage.inputAnnotatorTypes + stage.optionalInputAnnotatorTypes
             for input_col in input_cols:
                 annotator_type = annotator_types.get(input_col)
                 if annotator_type is None or annotator_type not in input_annotator_types:
diff --git a/python/test/annotator/er/entity_ruler_test.py b/python/test/annotator/er/entity_ruler_test.py
index f371f38d10759f..b50195a9b963e5 100644
--- a/python/test/annotator/er/entity_ruler_test.py
+++ b/python/test/annotator/er/entity_ruler_test.py
@@ -64,4 +64,24 @@ def runTest(self):
 
         self.assertTrue(result.select("entity").count() > 0)
 
 
+@pytest.mark.fast
+class EntityRulerLightPipelineTestSpec(unittest.TestCase):
+    def setUp(self):
+        self.empty_df = SparkContextForTest.spark.createDataFrame([[""]]).toDF("text")
+        self.path = os.getcwd() + "/../src/test/resources/entity-ruler/url_regex.json"
+
+    def runTest(self):
+        document_assembler = DocumentAssembler().setInputCol("text").setOutputCol("document")
+        tokenizer = Tokenizer().setInputCols('document').setOutputCol('token')
+
+        entity_ruler = EntityRulerApproach() \
+            .setInputCols(["document", "token"]) \
+            .setOutputCol("entity") \
+            .setPatternsResource(self.path)
+
+        pipeline = Pipeline(stages=[document_assembler, tokenizer, entity_ruler])
+        pipeline_model = pipeline.fit(self.empty_df)
+        light_pipeline = LightPipeline(pipeline_model)
+        result = light_pipeline.annotate("This is Google's URI http://google.com. And this is Yahoo's URI http://yahoo.com")
+        self.assertTrue(len(result["entity"]) == 2)
diff --git a/src/main/scala/com/johnsnowlabs/nlp/LightPipeline.scala b/src/main/scala/com/johnsnowlabs/nlp/LightPipeline.scala
index 9ed103d060b017..2271bd945c64b5 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/LightPipeline.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/LightPipeline.scala
@@ -296,7 +296,7 @@ class LightPipeline(val pipelineModel: PipelineModel, parseEmbeddings: Boolean =
       inputCols = inputCols ++ optionalColumns
     }
 
-    inputCols
+    inputCols.distinct
   }
 
   def fullAnnotateJava(target: String): java.util.Map[String, java.util.List[IAnnotation]] = {
diff --git a/src/test/resources/entity-ruler/url_regex.json b/src/test/resources/entity-ruler/url_regex.json
new file mode 100644
index 00000000000000..5ce2fcd8ebcd4d
--- /dev/null
+++ b/src/test/resources/entity-ruler/url_regex.json
@@ -0,0 +1,8 @@
+[
+  {
+    "id": "url-google",
+    "label": "URL",
+    "patterns": ["((?:(?:http|https)://)?(www.)?[a-zA-Z0-9@:%._\\+~#?&//=]{2,256}\\.(?:com|org|net|int|edu|gov|mil)(?:\\.[-a-zA-Z0-9:%_\\+~#?&//=]+)?)"],
+    "regex": true
+  }
+]
\ No newline at end of file