Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP make possible to load several components with the same name #1546

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions rasa_nlu/components.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
from typing import Hashable

from rasa_nlu import config
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.config import RasaNLUModelConfig, override_defaults
from rasa_nlu.training_data import Message

if typing.TYPE_CHECKING:
@@ -178,6 +178,7 @@ class Component(object):
language_list = None

def __init__(self, component_config=None):
# type: (Dict[Text, Any]) -> None
if not component_config:
component_config = {}

@@ -206,7 +207,7 @@ def load(cls,
model_dir=None, # type: Optional[Text]
model_metadata=None, # type: Optional[Metadata]
cached_component=None, # type: Optional[Component]
**kwargs # type: **Any
**kwargs # type: Any
):
# type: (...) -> Component
"""Load this component from file.
@@ -225,19 +226,19 @@ def load(cls,
return cls(component_config)

@classmethod
def create(cls, cfg):
# type: (RasaNLUModelConfig) -> Component
def create(cls, component_config):
# type: (Dict[Text, Any]) -> Component
"""Creates this component (e.g. before a training is started).
Method can access all configuration parameters."""

# Check language supporting
language = cfg.language
language = component_config.get('language')
if not cls.can_handle_language(language):
# check failed
raise UnsupportedLanguageError(cls.name, language)

return cls(cfg.for_component(cls.name, cls.defaults))
return cls(override_defaults(cls.defaults, component_config))

def provide_context(self):
# type: () -> Optional[Dict[Text, Any]]
@@ -427,15 +428,15 @@ def load_component(self,
"{}".format(component_name, e))

def create_component(self, component_name, cfg):
# type: (Text, RasaNLUModelConfig) -> Component
# type: (Text, Dict) -> Component
"""Tries to retrieve a component from the cache,
calls `create` to create a new component."""
from rasa_nlu import registry
from rasa_nlu.model import Metadata

try:
component, cache_key = self.__get_cached_component(
component_name, Metadata(cfg.as_dict(), None))
component_name, Metadata(cfg, None))
if component is None:
component = registry.create_component_by_name(component_name,
cfg)
4 changes: 1 addition & 3 deletions rasa_nlu/extractors/crf_entity_extractor.py
Original file line number Diff line number Diff line change
@@ -119,9 +119,7 @@ def required_packages(cls):
return ["sklearn_crfsuite", "sklearn"]

def train(self, training_data, config, **kwargs):
# type: (TrainingData, RasaNLUModelConfig) -> None

self.component_config = config.for_component(self.name, self.defaults)
# type: (TrainingData, RasaNLUModelConfig, Any) -> None

self._validate_configuration()

14 changes: 7 additions & 7 deletions rasa_nlu/extractors/duckling_extractor.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from typing import Optional
from typing import Text

from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.config import override_defaults
from rasa_nlu.extractors import EntityExtractor
from rasa_nlu.model import Metadata
from rasa_nlu.training_data import Message
@@ -110,11 +110,11 @@ def create_duckling_wrapper(cls, language):
raise Exception("Duckling error. {}".format(e))

@classmethod
def create(cls, config):
# type: (RasaNLUModelConfig) -> DucklingExtractor
def create(cls, component_config):
# type: (Dict[Text, Any]) -> DucklingExtractor

component_config = config.for_component(cls.name, cls.defaults)
dims = component_config.get("dimensions")
component_conf = override_defaults(cls.defaults, component_config)
dims = component_conf.get("dimensions")
if dims:
unknown_dimensions = [dim
for dim in dims
@@ -125,8 +125,8 @@ def create(cls, config):
"".format(", ".join(unknown_dimensions),
", ".join(cls.available_dimensions())))

wrapper = cls.create_duckling_wrapper(config["language"])
return DucklingExtractor(component_config, wrapper)
wrapper = cls.create_duckling_wrapper(component_conf["language"])
return DucklingExtractor(component_conf, wrapper)

@classmethod
def cache_key(cls, model_metadata):
18 changes: 7 additions & 11 deletions rasa_nlu/extractors/duckling_http_extractor.py
Original file line number Diff line number Diff line change
@@ -9,12 +9,9 @@

import requests
import simplejson
from typing import Any
from typing import List
from typing import Optional
from typing import Text
from typing import Any, List, Optional, Text, Dict

from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.config import override_defaults
from rasa_nlu.extractors import EntityExtractor
from rasa_nlu.extractors.duckling_extractor import (
filter_irrelevant_matches, convert_duckling_format_to_rasa)
@@ -49,18 +46,17 @@ class DucklingHTTPExtractor(EntityExtractor):
}

def __init__(self, component_config=None, language=None):
# type: (Text, Optional[List[Text]]) -> None
# type: (Dict[Text, Any], Optional[List[Text]]) -> None

super(DucklingHTTPExtractor, self).__init__(component_config)
self.language = language

@classmethod
def create(cls, config):
# type: (RasaNLUModelConfig) -> DucklingHTTPExtractor
def create(cls, component_config):
# type: (Dict[Text, Any]) -> DucklingHTTPExtractor

return cls(config.for_component(cls.name,
cls.defaults),
config.language)
return cls(override_defaults(cls.defaults, component_config),
component_config.get('language'))

def _locale(self):
if not self.component_config.get("locale"):
23 changes: 15 additions & 8 deletions rasa_nlu/model.py
Original file line number Diff line number Diff line change
@@ -157,19 +157,22 @@ def __init__(self,
@staticmethod
def _build_pipeline(cfg, component_builder):
# type: (RasaNLUModelConfig, ComponentBuilder) -> List
"""Transform the passed names of the pipeline components into classes"""
"""Transform the passed names of the pipeline components into classes
"""
pipeline = []

# Transform the passed names of the pipeline components into classes
for component_name in cfg.component_names:
for component in cfg.pipeline:
component_config = component.copy()
component_name = component_config.pop('name')
component_config['language'] = cfg.language
component = component_builder.create_component(
component_name, cfg)
component_name, component_config)
pipeline.append(component)

return pipeline

def train(self, data, **kwargs):
# type: (TrainingData) -> Interpreter
# type: (TrainingData, Any) -> Interpreter
"""Trains the underlying pipeline using the provided training data."""

self.training_data = data
@@ -200,9 +203,13 @@ def train(self, data, **kwargs):

return Interpreter(self.pipeline, context)

def persist(self, path, persistor=None, project_name=None,
fixed_model_name=None):
# type: (Text, Optional[Persistor], Text) -> Text
def persist(self,
path, # type: Text
persistor=None, # type: Optional[Persistor]
project_name=None, # type: Optional[Text]
fixed_model_name=None # type: Optional[Text]
):
# type: (...) -> Text
"""Persist all components of the pipeline to the passed path.
Returns the directory of the persisted model."""
15 changes: 6 additions & 9 deletions rasa_nlu/registry.py
Original file line number Diff line number Diff line change
@@ -10,10 +10,7 @@

import typing
from rasa_nlu import utils
from typing import Any
from typing import Optional
from typing import Text
from typing import Type
from typing import Any, Optional, Text, Type, Dict

from rasa_nlu.classifiers.keyword_intent_classifier import \
KeywordIntentClassifier
@@ -124,7 +121,7 @@ def load_component_by_name(component_name, # type: Text
model_dir, # type: Text
metadata, # type: Metadata
cached_component, # type: Optional[Component]
**kwargs # type: **Any
**kwargs # type: Any
):
# type: (...) -> Optional[Component]
"""Resolves a component and calls its load method to init it based on a
@@ -134,10 +131,10 @@ def load_component_by_name(component_name, # type: Text
return component_clz.load(model_dir, metadata, cached_component, **kwargs)


def create_component_by_name(component_name, config):
# type: (Text, RasaNLUModelConfig) -> Optional[Component]
def create_component_by_name(component_name, component_config):
# type: (Text, Dict) -> Optional[Component]
"""Resolves a component and calls it's create method to init it based on a
previously persisted model."""

component_clz = get_component_class(component_name)
return component_clz.create(config)
component_cls = get_component_class(component_name)
return component_cls.create(component_config)
8 changes: 4 additions & 4 deletions rasa_nlu/utils/mitie_utils.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
from typing import Text

from rasa_nlu.components import Component
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.config import RasaNLUModelConfig, override_defaults
from rasa_nlu.model import Metadata

if typing.TYPE_CHECKING:
@@ -49,11 +49,11 @@ def required_packages(cls):
return ["mitie"]

@classmethod
def create(cls, cfg):
# type: (RasaNLUModelConfig) -> MitieNLP
def create(cls, component_config):
# type: (Dict[Text, Any]) -> MitieNLP
import mitie

component_conf = cfg.for_component(cls.name, cls.defaults)
component_conf = override_defaults(cls.defaults, component_config)
model_file = component_conf.get("model")
if not model_file:
raise Exception("The MITIE component 'nlp_mitie' needs "
12 changes: 6 additions & 6 deletions rasa_nlu/utils/spacy_utils.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from typing import Text

from rasa_nlu.components import Component
from rasa_nlu.config import RasaNLUModelConfig
from rasa_nlu.config import RasaNLUModelConfig, override_defaults
from rasa_nlu.training_data import Message
from rasa_nlu.training_data import TrainingData

@@ -55,17 +55,17 @@ def required_packages(cls):
return ["spacy"]

@classmethod
def create(cls, cfg):
# type: (RasaNLUModelConfig) -> SpacyNLP
def create(cls, component_config):
# type: (Dict[Text, Any]) -> SpacyNLP
import spacy

component_conf = cfg.for_component(cls.name, cls.defaults)
component_conf = override_defaults(cls.defaults, component_config)
spacy_model_name = component_conf.get("model")

# if no model is specified, we fall back to the language string
if not spacy_model_name:
spacy_model_name = cfg.language
component_conf["model"] = cfg.language
spacy_model_name = component_config.get('language')
component_conf["model"] = component_config.get('language')

logger.info("Trying to load spacy model with "
"name '{}'".format(spacy_model_name))