diff --git a/CHANGELOG.md b/CHANGELOG.md index 40184b130..09f1b640e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ All notable changes to this project will be documented in this file. ### Fixed - Fixed a crash due to missing resources when refitting the `CRFSlotFiller` [#771](https://github.com/snipsco/snips-nlu/pull/771) - Fixed issue with egg fragments in download cli [#769](https://github.com/snipsco/snips-nlu/pull/769) +- Fixed an issue causing the `None` intent to be ignored when using the `parse` API in conjunction with `intents` and `top_n` [#781](https://github.com/snipsco/snips-nlu/pull/781) ## [0.19.4] - 2019-03-06 ### Added diff --git a/snips_nlu/nlu_engine/nlu_engine.py b/snips_nlu/nlu_engine/nlu_engine.py index a301fd477..7496190b7 100644 --- a/snips_nlu/nlu_engine/nlu_engine.py +++ b/snips_nlu/nlu_engine/nlu_engine.py @@ -178,7 +178,8 @@ def parse(self, text, intents=None, top_n=None): intents_results = self.get_intents(text) if intents is not None: intents_results = [res for res in intents_results - if res[RES_INTENT_NAME] in intents] + if res[RES_INTENT_NAME] is None + or res[RES_INTENT_NAME] in intents] intents_results = intents_results[:top_n] results = [] for intent_res in intents_results: diff --git a/snips_nlu/tests/test_nlu_engine.py b/snips_nlu/tests/test_nlu_engine.py index 87448ff8d..84ee94ba2 100644 --- a/snips_nlu/tests/test_nlu_engine.py +++ b/snips_nlu/tests/test_nlu_engine.py @@ -105,6 +105,8 @@ def get_slots(self, text, intent): # When results = nlu_engine.parse(text, top_n=3) + results_with_filter = nlu_engine.parse( + text, intents=["intent1", "intent3"], top_n=3) # Then expected_results = [ @@ -123,7 +125,23 @@ def get_slots(self, text, intent): [] ), ] + expected_results_with_filter = [ + extraction_result( + intent_classification_result("intent1", 0.5), + [custom_slot( + unresolved_slot((0, 3), "foo", "entity1", "slot1"))] + ), + extraction_result( + intent_classification_result(None, 0.15), + [] + ), + extraction_result( + intent_classification_result("intent3", 0.05), + [] + ), + ] self.assertListEqual(expected_results, results) + self.assertListEqual(expected_results_with_filter, results_with_filter) def test_should_get_intents(self): # Given