diff --git a/compose.yml b/compose.yml index 273e49bb7..dc477fd78 100644 --- a/compose.yml +++ b/compose.yml @@ -91,6 +91,18 @@ services: retries: 5 start_period: 30s + ner: + image: deeppavlov/deeppavlov:latest + profiles: + - extras + environment: + - CONFIG=ner_conll2003_bert + restart: unless-stopped + ports: + - 5000:5000 + volumes: + - ~/.deeppavlov:/root/.deeppavlov/ + - ~/.cache:/root/.cache/ dashboard: env_file: [.env_file] build: diff --git a/dff/pipeline/pipeline/actor.py b/dff/pipeline/pipeline/actor.py index c4f49a082..33a3d2daa 100644 --- a/dff/pipeline/pipeline/actor.py +++ b/dff/pipeline/pipeline/actor.py @@ -37,7 +37,7 @@ from dff.script.core.script import Script, Node from dff.script.core.normalization import normalize_label, normalize_response from dff.script.core.keywords import GLOBAL, LOCAL -from dff.pipeline.service.utils import wrap_sync_function_in_async +from dff.utils.devel.async_helpers import wrap_sync_function_in_async logger = logging.getLogger(__name__) diff --git a/dff/pipeline/pipeline/pipeline.py b/dff/pipeline/pipeline/pipeline.py index a96b7b814..b036c2dc0 100644 --- a/dff/pipeline/pipeline/pipeline.py +++ b/dff/pipeline/pipeline/pipeline.py @@ -25,6 +25,7 @@ from dff.messengers.console import CLIMessengerInterface from dff.messengers.common import MessengerInterface +from dff.slots.slots import GroupSlot from ..service.group import ServiceGroup from ..types import ( ServiceBuilder, @@ -56,6 +57,7 @@ class Pipeline: :param label_priority: Default priority value for all actor :py:const:`labels ` where there is no priority. Defaults to `1.0`. :param condition_handler: Handler that processes a call of actor condition functions. Defaults to `None`. + :param slots: Slots configuration. :param handlers: This variable is responsible for the usage of external handlers on the certain stages of work of :py:class:`~dff.script.Actor`. @@ -89,6 +91,7 @@ def __init__( fallback_label: Optional[NodeLabel2Type] = None, label_priority: float = 1.0, condition_handler: Optional[Callable] = None, + slots: Optional[Union[GroupSlot, Dict]] = None, handlers: Optional[Dict[ActorStage, List[Callable]]] = None, messenger_interface: Optional[MessengerInterface] = None, context_storage: Optional[Union[DBContextStorage, Dict]] = None, @@ -101,6 +104,7 @@ def __init__( self.actor: Actor = None self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface self.context_storage = {} if context_storage is None else context_storage + self.slots = GroupSlot.model_validate(slots) if slots is not None else None self._services_pipeline = ServiceGroup( components, before_handler=before_handler, @@ -208,6 +212,7 @@ def from_script( fallback_label: Optional[NodeLabel2Type] = None, label_priority: float = 1.0, condition_handler: Optional[Callable] = None, + slots: Optional[Union[GroupSlot, Dict]] = None, parallelize_processing: bool = False, handlers: Optional[Dict[ActorStage, List[Callable]]] = None, context_storage: Optional[Union[DBContextStorage, Dict]] = None, @@ -229,6 +234,7 @@ def from_script( :param label_priority: Default priority value for all actor :py:const:`labels ` where there is no priority. Defaults to `1.0`. :param condition_handler: Handler that processes a call of actor condition functions. Defaults to `None`. + :param slots: Slots configuration. :param parallelize_processing: This flag determines whether or not the functions defined in the ``PRE_RESPONSE_PROCESSING`` and ``PRE_TRANSITIONS_PROCESSING`` sections of the script should be parallelized over respective groups. @@ -257,6 +263,7 @@ def from_script( fallback_label=fallback_label, label_priority=label_priority, condition_handler=condition_handler, + slots=slots, parallelize_processing=parallelize_processing, handlers=handlers, messenger_interface=messenger_interface, @@ -320,6 +327,9 @@ async def _run_pipeline( if update_ctx_misc is not None: ctx.misc.update(update_ctx_misc) + if self.slots is not None: + ctx.framework_data.slot_manager.set_root_slot(self.slots) + ctx.add_request(request) result = await self._services_pipeline(ctx, self) diff --git a/dff/pipeline/service/extra.py b/dff/pipeline/service/extra.py index caebbd58c..f9194ec4f 100644 --- a/dff/pipeline/service/extra.py +++ b/dff/pipeline/service/extra.py @@ -14,7 +14,8 @@ from dff.script import Context -from .utils import collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates, wrap_sync_function_in_async +from .utils import collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates +from dff.utils.devel.async_helpers import wrap_sync_function_in_async from ..types import ( ServiceRuntimeInfo, ExtraHandlerType, diff --git a/dff/pipeline/service/service.py b/dff/pipeline/service/service.py index c0834fee4..9eae76468 100644 --- a/dff/pipeline/service/service.py +++ b/dff/pipeline/service/service.py @@ -17,7 +17,8 @@ from dff.script import Context -from .utils import wrap_sync_function_in_async, collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates +from .utils import collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates +from dff.utils.devel.async_helpers import wrap_sync_function_in_async from ..types import ( ServiceBuilder, StartConditionCheckerFunction, diff --git a/dff/pipeline/service/utils.py b/dff/pipeline/service/utils.py index b92b952e9..06d168ad2 100644 --- a/dff/pipeline/service/utils.py +++ b/dff/pipeline/service/utils.py @@ -5,24 +5,7 @@ These functions provide a variety of utility functionality. """ -import asyncio -from typing import Callable, Any, Optional, Tuple, Mapping - - -async def wrap_sync_function_in_async(func: Callable, *args, **kwargs) -> Any: - """ - Utility function, that wraps both functions and coroutines in coroutines. - Invokes `func` if it is just a callable and awaits, if this is a coroutine. - - :param func: Callable to wrap. - :param \\*args: Function args. - :param \\**kwargs: Function kwargs. - :return: What function returns. - """ - if asyncio.iscoroutinefunction(func): - return await func(*args, **kwargs) - else: - return func(*args, **kwargs) +from typing import Any, Optional, Tuple, Mapping def _get_attrs_with_updates( diff --git a/dff/script/core/context.py b/dff/script/core/context.py index 03c2cae5e..c84e5eee4 100644 --- a/dff/script/core/context.py +++ b/dff/script/core/context.py @@ -27,6 +27,7 @@ from dff.script.core.message import Message from dff.script.core.types import NodeLabel2Type from dff.pipeline.types import ComponentExecutionState +from dff.slots.slots import SlotManager if TYPE_CHECKING: from dff.script.core.script import Node @@ -56,6 +57,8 @@ class FrameworkData(BaseModel): "Actor service data. Cleared at the end of every turn." stats: Dict[str, Any] = Field(default_factory=dict) "Enables complex stats collection across multiple turns." + slot_manager: SlotManager = Field(default_factory=SlotManager) + "Stores extracted slots." class Context(BaseModel): diff --git a/dff/script/core/normalization.py b/dff/script/core/normalization.py index e6926227d..2784b2647 100644 --- a/dff/script/core/normalization.py +++ b/dff/script/core/normalization.py @@ -33,9 +33,11 @@ def normalize_label(label: Label, default_flow_label: LabelType = "") -> Label: """ if callable(label): - def get_label_handler(ctx: Context, pipeline: Pipeline) -> ConstLabel: + def get_label_handler(ctx: Context, pipeline: Pipeline) -> Optional[ConstLabel]: try: new_label = label(ctx, pipeline) + if new_label is None: + return None new_label = normalize_label(new_label, default_flow_label) flow_label, node_label, _ = new_label node = pipeline.script.get(flow_label, {}).get(node_label) diff --git a/dff/slots/__init__.py b/dff/slots/__init__.py new file mode 100644 index 000000000..7579a8523 --- /dev/null +++ b/dff/slots/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# flake8: noqa: F401 + +from dff.slots.slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot +from dff.slots.conditions import slots_extracted +from dff.slots.processing import extract, extract_all, unset, unset_all, fill_template +from dff.slots.response import filled_template diff --git a/dff/slots/conditions.py b/dff/slots/conditions.py new file mode 100644 index 000000000..80ebdf222 --- /dev/null +++ b/dff/slots/conditions.py @@ -0,0 +1,32 @@ +""" +Conditions +--------------------------- +Provides slot-related conditions. +""" + +from __future__ import annotations +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + from dff.script import Context + from dff.slots.slots import SlotName + from dff.pipeline import Pipeline + + +def slots_extracted(*slots: SlotName, mode: Literal["any", "all"] = "all"): + """ + Conditions that checks if slots are extracted. + + :param slots: Names for slots that need to be checked. + :param mode: Whether to check if all slots are extracted or any slot is extracted. + """ + + def check_slot_state(ctx: Context, pipeline: Pipeline) -> bool: + manager = ctx.framework_data.slot_manager + if mode == "all": + return all(manager.is_slot_extracted(slot) for slot in slots) + elif mode == "any": + return any(manager.is_slot_extracted(slot) for slot in slots) + raise ValueError(f"{mode!r} not in ['any', 'all'].") + + return check_slot_state diff --git a/dff/slots/processing.py b/dff/slots/processing.py new file mode 100644 index 000000000..1f99c4c23 --- /dev/null +++ b/dff/slots/processing.py @@ -0,0 +1,98 @@ +""" +Processing +--------------------------- +This module provides wrappers for :py:class:`~dff.slots.slots.SlotManager`'s API. +""" + +from __future__ import annotations + +import logging +from typing import Awaitable, Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from dff.slots.slots import SlotName + from dff.script import Context + from dff.pipeline import Pipeline + +logger = logging.getLogger(__name__) + + +def extract(*slots: SlotName) -> Callable[[Context, Pipeline], Awaitable[None]]: + """ + Extract slots listed slots. + This will override all slots even if they are already extracted. + + :param slots: List of slot names to extract. + """ + + async def inner(ctx: Context, pipeline: Pipeline) -> None: + manager = ctx.framework_data.slot_manager + for slot in slots: # todo: maybe gather + await manager.extract_slot(slot, ctx, pipeline) + + return inner + + +def extract_all(): + """ + Extract all slots defined in the pipeline. + """ + + async def inner(ctx: Context, pipeline: Pipeline): + manager = ctx.framework_data.slot_manager + await manager.extract_all(ctx, pipeline) + + return inner + + +def unset(*slots: SlotName) -> Callable[[Context, Pipeline], None]: + """ + Mark specified slots as not extracted and clear extracted values. + + :param slots: List of slot names to extract. + """ + + def unset_inner(ctx: Context, pipeline: Pipeline) -> None: + manager = ctx.framework_data.slot_manager + for slot in slots: + manager.unset_slot(slot) + + return unset_inner + + +def unset_all(): + """ + Mark all slots as not extracted and clear all extracted values. + """ + + def inner(ctx: Context, pipeline: Pipeline): + manager = ctx.framework_data.slot_manager + manager.unset_all_slots() + + return inner + + +def fill_template() -> Callable[[Context, Pipeline], None]: + """ + Fill the response template in the current node. + + Response message of the current node should be a format-string: e.g. "Your username is {profile.username}". + """ + + def inner(ctx: Context, pipeline: Pipeline) -> None: + manager = ctx.framework_data.slot_manager + # get current node response + response = ctx.current_node.response + + if response is None: + return + + if callable(response): + response = response(ctx, pipeline) + + new_text = manager.fill_template(response.text) + + response.text = new_text + ctx.current_node.response = response + + return inner diff --git a/dff/slots/response.py b/dff/slots/response.py new file mode 100644 index 000000000..152b79ddb --- /dev/null +++ b/dff/slots/response.py @@ -0,0 +1,34 @@ +""" +Response +--------------------------- +Slot-related DFF responses. +""" + +from __future__ import annotations +from typing import Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from dff.script import Context, Message + from dff.pipeline import Pipeline + + +def filled_template(template: Message) -> Callable[[Context, Pipeline], Message]: + """ + Fill template with slot values. + The `text` attribute of the template message should be a format-string: + e.g. "Your username is {profile.username}". + + For the example above, if ``profile.username`` slot has value "admin", + it would return a copy of the message with the following text: + "Your username is admin". + + :param template: Template message with a format-string text. + """ + + def fill_inner(ctx: Context, pipeline: Pipeline) -> Message: + message = template.model_copy() + new_text = ctx.framework_data.slot_manager.fill_template(template.text) + message.text = new_text + return message + + return fill_inner diff --git a/dff/slots/slots.py b/dff/slots/slots.py new file mode 100644 index 000000000..536501751 --- /dev/null +++ b/dff/slots/slots.py @@ -0,0 +1,418 @@ +""" +Slots +----- +This module defines base classes for slots and some concrete implementations of them. +""" + +from __future__ import annotations + +import asyncio +import re +from abc import ABC, abstractmethod +from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union +from typing_extensions import TypeAlias +import logging +from functools import reduce + +from pydantic import BaseModel, model_validator, Field + +from dff.utils.devel.async_helpers import wrap_sync_function_in_async +from dff.utils.devel.json_serialization import PickleEncodedValue + +if TYPE_CHECKING: + from dff.script import Context, Message + from dff.pipeline.pipeline.pipeline import Pipeline + + +logger = logging.getLogger(__name__) + + +SlotName: TypeAlias = str +""" +A string to identify slots. + +Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. + +E.g. + +.. code:: python + + GroupSlot( + user=RegexpSlot(), + password=FunctionSlot, + ) + +Has two slots with names "user" and "password". + +For nested group slots use dots to separate names: + +.. code:: python + + GroupSlot( + user=GroupSlot( + name=FunctionSlot, + password=FunctionSlot, + ) + ) + +Has two slots with names "user.name" and "user.password". +""" + + +def recursive_getattr(obj, slot_name: SlotName): + def two_arg_getattr(__o, name): + # pydantic handles exception when accessing a non-existing extra-field on its own + # return None by default to avoid that + return getattr(__o, name, None) + + return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) + + +def recursive_setattr(obj, slot_name: SlotName, value): + parent_slot, _, slot = slot_name.rpartition(".") + + if parent_slot: + setattr(recursive_getattr(obj, parent_slot), slot, value) + else: + setattr(obj, slot, value) + + +class SlotNotExtracted(Exception): + """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" + + pass + + +class ExtractedSlot(BaseModel, ABC): + """ + Represents value of an extracted slot. + + Instances of this class are managed by framework and + are stored in :py:attr:`~dff.script.core.context.FrameworkData.slot_manager`. + They can be accessed via the ``ctx.framework_data.slot_manager.get_extracted_slot`` method. + """ + + @property + @abstractmethod + def __slot_extracted__(self) -> bool: + """Whether the slot is extracted.""" + raise NotImplementedError + + def __unset__(self): + """Mark slot as not extracted and clear extracted data (except for default value).""" + raise NotImplementedError + + @abstractmethod + def __str__(self): + """String representation is used to fill templates.""" + raise NotImplementedError + + +class ExtractedValueSlot(ExtractedSlot): + """Value extracted from :py:class:`~.ValueSlot`.""" + + is_slot_extracted: bool + extracted_value: PickleEncodedValue + default_value: PickleEncodedValue = None + + @property + def __slot_extracted__(self) -> bool: + return self.is_slot_extracted + + def __unset__(self): + self.is_slot_extracted = False + self.extracted_value = SlotNotExtracted("Slot manually unset.") + + @property + def value(self): + """Extracted value or the default value if the slot is not extracted.""" + return self.extracted_value if self.is_slot_extracted else self.default_value + + def __str__(self): + return str(self.value) + + +class ExtractedGroupSlot(ExtractedSlot, extra="allow"): + __pydantic_extra__: dict[str, Union["ExtractedValueSlot", "ExtractedGroupSlot"]] + + @property + def __slot_extracted__(self) -> bool: + return all([slot.__slot_extracted__ for slot in self.__pydantic_extra__.values()]) + + def __unset__(self): + for child in self.__pydantic_extra__.values(): + child.__unset__() + + def __str__(self): + return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) + + def update(self, old: "ExtractedGroupSlot"): + """ + Rebase this extracted groups slot on top of another one. + This is required to merge slot storage in-context + with a potentially different slot configuration passed to pipeline. + + :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. + Extracted values will be transferred to this object. + """ + for slot in old.__pydantic_extra__: + if slot in self.__pydantic_extra__: + new_slot = self.__pydantic_extra__[slot] + old_slot = old.__pydantic_extra__[slot] + if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): + new_slot.update(old_slot) + if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): + self.__pydantic_extra__[slot] = old_slot + + +class BaseSlot(BaseModel, frozen=True): + """ + BaseSlot is a base class for all slots. + """ + + @abstractmethod + async def get_value(self, ctx: Context, pipeline: Pipeline) -> ExtractedSlot: + """ + Extract slot value from :py:class:`~.Context` and return an instance of :py:class:`~.ExtractedSlot`. + """ + raise NotImplementedError + + @abstractmethod + def init_value(self) -> ExtractedSlot: + """ + Provide an initial value to fill slot storage with. + """ + raise NotImplementedError + + +class ValueSlot(BaseSlot, frozen=True): + """ + Value slot is a base class for all slots that are designed to extract concrete values. + Subclass it, if you want to declare your own slot type. + """ + + default_value: Any = None + + @abstractmethod + async def extract_value(self, ctx: Context, pipeline: Pipeline) -> Union[Any, SlotNotExtracted]: + """ + Return value extracted from context. + + Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. + + Raising exceptions is also allowed and will result in an unsuccessful extraction as well. + """ + raise NotImplementedError + + async def get_value(self, ctx: Context, pipeline: Pipeline) -> ExtractedValueSlot: + """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" + extracted_value = SlotNotExtracted("Caught an exit exception.") + is_slot_extracted = False + + try: + extracted_value = await self.extract_value(ctx, pipeline) + is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) + except Exception as error: + logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) + extracted_value = error + finally: + return ExtractedValueSlot.model_construct( + is_slot_extracted=is_slot_extracted, + extracted_value=extracted_value, + default_value=self.default_value, + ) + + def init_value(self) -> ExtractedValueSlot: + return ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted("Initial slot extraction."), + default_value=self.default_value, + ) + + +class GroupSlot(BaseSlot, extra="allow", frozen=True): + """ + Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. + """ + + __pydantic_extra__: dict[str, Union["ValueSlot", "GroupSlot"]] + + def __init__(self, **kwargs): # supress unexpected argument warnings + super().__init__(**kwargs) + + @model_validator(mode="after") + def __check_extra_field_names__(self): + """ + Extra field names cannot be dunder names or contain dots. + """ + for field in self.__pydantic_extra__.keys(): + if "." in field: + raise ValueError(f"Extra field name cannot contain dots: {field!r}") + if field.startswith("__") and field.endswith("__"): + raise ValueError(f"Extra field names cannot be dunder: {field!r}") + return self + + async def get_value(self, ctx: Context, pipeline: Pipeline) -> ExtractedGroupSlot: + child_values = await asyncio.gather( + *(child.get_value(ctx, pipeline) for child in self.__pydantic_extra__.values()) + ) + return ExtractedGroupSlot( + **{child_name: child_value for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys())} + ) + + def init_value(self) -> ExtractedGroupSlot: + return ExtractedGroupSlot( + **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} + ) + + +class RegexpSlot(ValueSlot, frozen=True): + """ + RegexpSlot is a slot type that extracts its value using a regular expression. + You can pass a compiled or a non-compiled pattern to the `regexp` argument. + If you want to extract a particular group, but not the full match, + change the `match_group_idx` parameter. + """ + + regexp: str + match_group_idx: int = 0 + "Index of the group to match." + + async def extract_value(self, ctx: Context, _: Pipeline) -> Union[str, SlotNotExtracted]: + request_text = ctx.last_request.text + search = re.search(self.regexp, request_text) + return ( + search.group(self.match_group_idx) + if search + else SlotNotExtracted(f"Failed to match pattern {self.regexp!r} in {request_text!r}.") + ) + + +class FunctionSlot(ValueSlot, frozen=True): + """ + A simpler version of :py:class:`~.ValueSlot`. + + Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. + """ + + func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]] + + async def extract_value(self, ctx: Context, _: Pipeline) -> Union[Any, SlotNotExtracted]: + return await wrap_sync_function_in_async(self.func, ctx.last_request) + + +class SlotManager(BaseModel): + """ + Provides API for managing slots. + + An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. + """ + + slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) + """Slot storage. Stored inside ctx.framework_data.""" + root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) + """Slot configuration passed during pipeline initialization.""" + + def set_root_slot(self, root_slot: GroupSlot): + """ + Set root_slot configuration from pipeline. + Update extracted slots with the new configuration: + + New slots are added with their :py:meth:`~.BaseSlot.init_value`. + Old extracted slot values are preserved only if their configuration did not change. + That is if they are still present in the config and if their fundamental type did not change + (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). + + This method is called by pipeline and is not supposed to be used otherwise. + """ + self.root_slot = root_slot + new_slot_storage = root_slot.init_value() + new_slot_storage.update(self.slot_storage) + self.slot_storage = new_slot_storage + + def get_slot(self, slot_name: SlotName) -> BaseSlot: + """ + Get slot configuration from the slot name. + + :raises KeyError: If the slot with the specified name does not exist. + """ + try: + slot = recursive_getattr(self.root_slot, slot_name) + if isinstance(slot, BaseSlot): + return slot + except (AttributeError, KeyError): + pass + raise KeyError(f"Could not find slot {slot_name!r}.") + + async def extract_slot(self, slot_name: SlotName, ctx: Context, pipeline: Pipeline) -> None: + """ + Extract slot `slot_name` and store extracted value in `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = self.get_slot(slot_name) + value = await slot.get_value(ctx, pipeline) + + recursive_setattr(self.slot_storage, slot_name, value) + + async def extract_all(self, ctx: Context, pipeline: Pipeline): + """ + Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. + """ + self.slot_storage = await self.root_slot.get_value(ctx, pipeline) + + def get_extracted_slot(self, slot_name: SlotName) -> ExtractedSlot: + """ + Retrieve extracted value from `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + """ + try: + slot = recursive_getattr(self.slot_storage, slot_name) + if isinstance(slot, ExtractedSlot): + return slot + except (AttributeError, KeyError): + pass + raise KeyError(f"Could not find slot {slot_name!r}.") + + def is_slot_extracted(self, slot_name: str) -> bool: + """ + Return if the specified slot is extracted. + + :raises KeyError: If the slot with the specified name does not exist. + """ + return self.get_extracted_slot(slot_name).__slot_extracted__ + + def all_slots_extracted(self) -> bool: + """ + Return if all slots are extracted. + """ + return self.slot_storage.__slot_extracted__ + + def unset_slot(self, slot_name: SlotName) -> None: + """ + Mark specified slot as not extracted and clear extracted value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + self.get_extracted_slot(slot_name).__unset__() + + def unset_all_slots(self) -> None: + """ + Mark all slots as not extracted and clear all extracted values. + """ + self.slot_storage.__unset__() + + def fill_template(self, template: str) -> str: + """ + Fill `template` string with extracted slot values and return a formatted string. + + `template` should be a format-string: + + E.g. "Your username is {profile.username}". + + For the example above, if ``profile.username`` slot has value "admin", + it would return the following text: + "Your username is admin". + """ + return template.format(**dict(self.slot_storage.__pydantic_extra__.items())) diff --git a/dff/utils/devel/__init__.py b/dff/utils/devel/__init__.py index 08ff1afbc..affbce004 100644 --- a/dff/utils/devel/__init__.py +++ b/dff/utils/devel/__init__.py @@ -11,3 +11,4 @@ JSONSerializableExtras, ) from .extra_field_helpers import grab_extra_fields +from .async_helpers import wrap_sync_function_in_async diff --git a/dff/utils/devel/async_helpers.py b/dff/utils/devel/async_helpers.py new file mode 100644 index 000000000..13cbc640b --- /dev/null +++ b/dff/utils/devel/async_helpers.py @@ -0,0 +1,24 @@ +""" +Async Helpers +------------- +Tools to help with async. +""" + +import asyncio +from typing import Callable, Any + + +async def wrap_sync_function_in_async(func: Callable, *args, **kwargs) -> Any: + """ + Utility function, that wraps both functions and coroutines in coroutines. + Invokes `func` if it is just a callable and awaits, if this is a coroutine. + + :param func: Callable to wrap. + :param \\*args: Function args. + :param \\**kwargs: Function kwargs. + :return: What function returns. + """ + if asyncio.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) diff --git a/dff/utils/devel/json_serialization.py b/dff/utils/devel/json_serialization.py index 017fc791e..f198dc47c 100644 --- a/dff/utils/devel/json_serialization.py +++ b/dff/utils/devel/json_serialization.py @@ -141,6 +141,8 @@ class MyClass(BaseModel): my_obj = MyClass() # the field cannot be set during init my_obj.my_field = unserializable_object # can be set manually to avoid validation +Alternatively, ``BaseModel.model_construct`` may be used to bypass validation, +though it would bypass validation of all fields. """ JSONPickleSerializer = PlainSerializer(json_pickle_serializer, when_used="json") diff --git a/docs/source/conf.py b/docs/source/conf.py index a4b65beb8..dd7c92bab 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -183,6 +183,7 @@ def setup(_): ("responses", "Responses"), ], ), + ("tutorials.slots", "Slots"), ("tutorials.utils", "Utils"), ("tutorials.stats", "Stats"), ] @@ -193,6 +194,7 @@ def setup(_): ("dff.messengers", "Messenger Interfaces"), ("dff.pipeline", "Pipeline"), ("dff.script", "Script"), + ("dff.slots", "Slots"), ("dff.stats", "Stats"), ("dff.utils.testing", "Testing Utils"), ("dff.utils.turn_caching", "Caching"), diff --git a/docs/source/user_guides.rst b/docs/source/user_guides.rst index 0cb1a1531..635ebb031 100644 --- a/docs/source/user_guides.rst +++ b/docs/source/user_guides.rst @@ -9,6 +9,14 @@ those include but are not limited to: dialog graph creation, specifying start an setting transitions and conditions, using ``Context`` object in order to receive information about current script execution. +:doc:`Slot extraction <./user_guides/slot_extraction>` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``slot extraction`` guide demonstrates the slot extraction functionality +currently integrated in the library. ``DFF`` only provides basic building blocks for this task, +which can be trivially extended to support any NLU engine or slot extraction model +of your liking. + :doc:`Context guide <./user_guides/context_guide>` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -34,6 +42,7 @@ and to locate and remove performance bottlenecks. :hidden: user_guides/basic_conceptions + user_guides/slot_extraction user_guides/context_guide user_guides/superset_guide user_guides/optimization_guide diff --git a/docs/source/user_guides/slot_extraction.rst b/docs/source/user_guides/slot_extraction.rst new file mode 100644 index 000000000..fc2b697e3 --- /dev/null +++ b/docs/source/user_guides/slot_extraction.rst @@ -0,0 +1,176 @@ +Slot Extraction +--------------- + +Introduction +~~~~~~~~~~~~ + +Extracting and filling slots is an essential part of any conversational service +that comprises the inherent business logic. Like most frameworks, DFF +provides components that address this task as a part of its ``slots`` module. +These can be easily customized to leverage neural networks specifically designed +for slot extraction or any other logic you might want to integrate. + +API overview +~~~~~~~~~~~~ + +Defining slots +============== + +The basic building block of the API is the +`BaseSlot <../apiref/dff.slots.slots.html#dff.slots.slots.BaseSlot>`_ class +and its descendants that vary depending on the value extraction logic. +Each slot has a name by which it can be accessed and a method for extracting values. +Below, we demonstrate the most basic class that extracts values +from user utterances using a regular expression: +`RegexpSlot <../apiref/dff.slots.slots.html#dff.slots.types.RegexpSlot>`_. + +.. code-block:: python + + from dff.slots import RegexpSlot + ... + email_slot = RegexpSlot(regexp=r"[a-z@\.A-Z]+") + +The slots can implement arbitrary logic including requests to external services. +For instance, Deeppavlov library includes a number of models that may be of use for slot +extraction task. In particular, we will demonstrate the use of the following +`NER model `_ +that was trained and validated on the conll_2003 dataset. + +.. code-block:: shell + + docker pull deeppavlov/deeppavlov:latest + docker run -d --name=ner \ + -e CONFIG=ner_conll2003_bert \ + -p 5000:5000 \ + -v ~/.deeppavlov:/root/deeppavlov \ + -v ~/.cache:/root/cache \ + deeppavlov/deeppavlov:latest + +Now that you have a Deeppavlov docker image running on port 5000, you can take the following steps to take +full advantage of its predictions. + +.. code-block:: python + + import requests + from dff.slots import FunctionSlot + from dff.script import Message + + # we assume that there is a 'NER' service running on port 5000 + def extract_first_name(utterance: Message) -> str: + """Return the first entity of type B-PER (first name) found in the utterance.""" + ner_request = requests.post( + "http://localhost:5000/model", + json={"x": [utterance.text]} + ) + ner_tuple = ner_request.json() + if "B-PER" not in ner_tuple[1][0]: + return "" + return ner_tuple[0][0][ner_tuple[1][0].index("B-PER")] + + name_slot = FunctionSlot(func=extract_first_name) + +Individual slots can be grouped allowing the developer to access them together +as a namespace. This can be achieved using the +`GroupSlot <../apiref/dff.slots.slots.html#dff.slots.slots.GroupSlot>`_ +component that is initialized with other slot instances as its children. +The group slots also allows for arbitrary nesting, i.e. it is possible to include +group slots in other group slots. + +.. code-block:: python + + from dff.slots import GroupSlot + + profile_slot = GroupSlot(name=name_slot, email=email_slot) + +After defining all your slots, pass ``GroupSlot`` as pipeline's `slots` argument. +That slot is a root slot: it contains all other group and value slots. + +.. code-block:: python + + from dff.pipeline import Pipeline + + pipeline = Pipeline.from_script(..., slots=profile_slot) + +Slot names +========== + +Any slot can be accessed by a slot name: +A dot-separated string that acts as a path from the root slot to the needed slot. + +In the example above ``name_slot`` would have the name "name" +because that is the key used to store it in the ``profile_slot``. + +If you have a nested structure (of ``GroupSlots``) separate the names with dots: + +.. code-block:: python + + from dff.slots import GroupSlot + + root_slot = GroupSlot(profile=GroupSlot(name=name_slot, email=email_slot)) + +In this example ``name_slot`` would be accessible by the "profile.name" name. + +Using slots +=========== + +Slots can be extracted at the ``PRE_TRANSITIONS_PROCESSING`` stage +using the `extract <../apiref/dff.slots.processing.html#dff.slots.processing.extract>`_ +function from the `processing` submodule. +You can pass any number of names of the slots that you want to extract to this function. + +.. code-block:: python + + from dff.slots.processing import extract + + PRE_TRANSITIONS_PROCESSING: {"extract_first_name": extract("name", "email")} + +The `conditions` submodule provides a function for checking if specific slots have been extracted. + +.. code-block:: python + + from dff.slots.conditions import slots_extracted + + TRANSITIONS: {"all_information": slots_extracted("name", "email", mode="all")} + TRANSITIONS: {"partial_information": slots_extracted("name", "email", mode="any")} + +.. note:: + + You can combine ``slots_extracted`` with the + `negation <../apiref/dff.script.conditions.std_conditions.html#dff.script.conditions.std_conditions.negation>`_ + condition to make a transition to an extractor node if a slot has not been extracted yet. + +Both `processing` and `response` submodules provide functions for filling templates with +extracted slot values. +Choose whichever one you like, there's not much difference between them at the moment. + +.. code-block:: python + + from dff.slots.processing import fill_template + from dff.slots.response import filled_template + + PRE_RESPONSE_PROCESSING: {"fill_response_slots": slot_procs.fill_template()} + RESPONSE: Message(text="Your first name: {name}") + + + RESPONSE: filled_template(Message(text="Your first name: {name}")) + +Some real examples of scripts utilizing slot extraction can be found in the +`tutorials section <../tutorials/tutorials.slots.1_basic_example.html>`_. + +Further reading +=============== + +All of the functions described in the previous sections call methods of the +`SlotManager <../apiref/dff.slots.slots.html#dff.slots.slots.SlotManager>`_ +class under the hood. + +An instance of this class can be accessed in runtime via ``ctx.framework_data.slot_manager``. + +This class allows for more detailed access to the slots API. +For example, you can access exceptions that occurred during slot extraction: + +.. code-block:: python + + slot_manager = ctx.framework_data.slot_manager + extracted_value = slot_manager.get_extracted_slot("name") + exception = extracted_value.extracted_value if not extracted_value.is_slot_extracted else None diff --git a/poetry.lock b/poetry.lock index 5f1b51b77..eadc0df50 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1486,6 +1486,20 @@ files = [ dnspython = ">=2.0.0" idna = ">=2.0.0" +[[package]] +name = "eval-type-backport" +version = "0.2.0" +description = "Like `typing._eval_type`, but lets older Python versions use newer typing features." +optional = false +python-versions = ">=3.8" +files = [ + {file = "eval_type_backport-0.2.0-py3-none-any.whl", hash = "sha256:ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933"}, + {file = "eval_type_backport-0.2.0.tar.gz", hash = "sha256:68796cfbc7371ebf923f03bdf7bef415f3ec098aeced24e054b253a0e78f7b37"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "exceptiongroup" version = "1.2.1" @@ -1821,8 +1835,8 @@ files = [ [package.dependencies] cffi = {version = ">=1.12.2", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""} greenlet = [ - {version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""}, {version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""}, + {version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""}, ] "zope.event" = "*" "zope.interface" = "*" @@ -3880,8 +3894,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -5032,7 +5046,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -7254,4 +7267,4 @@ ydb = ["six", "ydb"] [metadata] lock-version = "2.0" python-versions = "^3.8.1,!=3.9.7" -content-hash = "0fea55ff020381487754e65f4630de8bfaedc65c37b6d071763d427f1667b7b3" +content-hash = "a4e53a8b58504d6e4f877ac5e7901d5aa8451003bf9edf55ebfb4df7af8424ab" diff --git a/pyproject.toml b/pyproject.toml index a7ac29fb4..948809258 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ python = "^3.8.1,!=3.9.7" # `streamlit` package does not support python 3.9.7, pydantic = ">=2.0" # `pydantic` version more than 2 required nest-asyncio = "*" typing-extensions = "*" +eval_type_backport = "*" wrapt = "*" colorama = "*" ydb = { version = "*", optional = true } @@ -212,6 +213,7 @@ markers = [ "all: reserved by allow-skip", "none: reserved by allow-skip", ] +asyncio_mode = "auto" [tool.coverage.run] diff --git a/tests/slots/__init__.py b/tests/slots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/slots/conftest.py b/tests/slots/conftest.py new file mode 100644 index 000000000..a75b6fa69 --- /dev/null +++ b/tests/slots/conftest.py @@ -0,0 +1,28 @@ +import pytest + +from dff.script import Message, TRANSITIONS, RESPONSE, Context +from dff.script import conditions as cnd +from dff.pipeline import Pipeline +from dff.slots.slots import SlotNotExtracted + + +@pytest.fixture(scope="function", autouse=True) +def patch_exception_equality(monkeypatch): + monkeypatch.setattr( + SlotNotExtracted, "__eq__", lambda self, other: type(self) is type(other) and self.args == other.args + ) + yield + + +@pytest.fixture(scope="function") +def pipeline(): + script = {"flow": {"node": {RESPONSE: Message(), TRANSITIONS: {"node": cnd.true()}}}} + pipeline = Pipeline.from_script(script=script, start_label=("flow", "node")) + return pipeline + + +@pytest.fixture(scope="function") +def context(): + ctx = Context() + ctx.add_request(Message(text="Hi")) + return ctx diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py new file mode 100644 index 000000000..12f786038 --- /dev/null +++ b/tests/slots/test_slot_manager.py @@ -0,0 +1,262 @@ +import pytest + +from dff.slots.slots import ( + SlotManager, + RegexpSlot, + GroupSlot, + FunctionSlot, + ExtractedGroupSlot, + ExtractedValueSlot, + SlotNotExtracted, +) +from dff.script import Message + + +def faulty_func(_): + raise SlotNotExtracted("Error.") + + +init_value_slot = ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted("Initial slot extraction."), + default_value=None, +) + + +root_slot = GroupSlot( + person=GroupSlot( + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + surname=FunctionSlot(func=faulty_func), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + msg_len=FunctionSlot(func=lambda msg: len(msg.text)), +) + + +extracted_slot_values = { + "person.name": ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + "person.surname": ExtractedValueSlot.model_construct( + is_slot_extracted=False, extracted_value=SlotNotExtracted("Error."), default_value=None + ), + "person.email": ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="bot@bot", default_value=None + ), + "msg_len": ExtractedValueSlot.model_construct(is_slot_extracted=True, extracted_value=29, default_value=None), +} + + +extracted_slot_values["person"] = ExtractedGroupSlot( + name=extracted_slot_values["person.name"], + surname=extracted_slot_values["person.surname"], + email=extracted_slot_values["person.email"], +) + + +unset_slot = ExtractedValueSlot.model_construct( + is_slot_extracted=False, extracted_value=SlotNotExtracted("Slot manually unset."), default_value=None +) + + +init_slot_storage = ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=init_value_slot, + surname=init_value_slot, + email=init_value_slot, + ), + msg_len=init_value_slot, +) + + +unset_slot_storage = ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=unset_slot, + surname=unset_slot, + email=unset_slot, + ), + msg_len=unset_slot, +) + + +full_slot_storage = ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=extracted_slot_values["person.name"], + surname=extracted_slot_values["person.surname"], + email=extracted_slot_values["person.email"], + ), + msg_len=extracted_slot_values["msg_len"], +) + + +class TestSlotManager: + @pytest.fixture(scope="function") + def context_with_request(self, context): + new_ctx = context.model_copy(deep=True) + new_ctx.add_request(Message(text="I am Bot. My email is bot@bot")) + return new_ctx + + async def test_init_slot_storage(self): + assert root_slot.init_value() == init_slot_storage + + @pytest.fixture(scope="function") + def empty_slot_manager(self): + manager = SlotManager() + manager.set_root_slot(root_slot) + return manager + + @pytest.fixture(scope="function") + def extracted_slot_manager(self): + slot_storage = full_slot_storage.model_copy(deep=True) + return SlotManager(root_slot=root_slot, slot_storage=slot_storage) + + @pytest.fixture(scope="function") + def fully_extracted_slot_manager(self): + slot_storage = full_slot_storage.model_copy(deep=True) + slot_storage.person.surname = ExtractedValueSlot.model_construct( + extracted_value="Bot", is_slot_extracted=True, default_value=None + ) + return SlotManager(root_slot=root_slot, slot_storage=slot_storage) + + def test_get_slot_by_name(self, empty_slot_manager): + assert empty_slot_manager.get_slot("person.name").regexp == r"(?<=am ).+?(?=\.)" + assert empty_slot_manager.get_slot("person.email").regexp == r"[a-zA-Z\.]+@[a-zA-Z\.]+" + assert isinstance(empty_slot_manager.get_slot("person"), GroupSlot) + assert isinstance(empty_slot_manager.get_slot("msg_len"), FunctionSlot) + + with pytest.raises(KeyError): + empty_slot_manager.get_slot("person.birthday") + + with pytest.raises(KeyError): + empty_slot_manager.get_slot("intent") + + @pytest.mark.parametrize( + "slot_name,expected_slot_storage", + [ + ( + "person.name", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=extracted_slot_values["person.name"], + surname=init_value_slot, + email=init_value_slot, + ), + msg_len=init_value_slot, + ), + ), + ( + "person", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=extracted_slot_values["person.name"], + surname=extracted_slot_values["person.surname"], + email=extracted_slot_values["person.email"], + ), + msg_len=init_value_slot, + ), + ), + ( + "msg_len", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=init_value_slot, + surname=init_value_slot, + email=init_value_slot, + ), + msg_len=extracted_slot_values["msg_len"], + ), + ), + ], + ) + async def test_slot_extraction( + self, slot_name, expected_slot_storage, empty_slot_manager, context_with_request, pipeline + ): + await empty_slot_manager.extract_slot(slot_name, context_with_request, pipeline) + assert empty_slot_manager.slot_storage == expected_slot_storage + + async def test_extract_all(self, empty_slot_manager, context_with_request, pipeline): + await empty_slot_manager.extract_all(context_with_request, pipeline) + assert empty_slot_manager.slot_storage == full_slot_storage + + @pytest.mark.parametrize( + "slot_name, expected_slot_storage", + [ + ( + "person.name", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=unset_slot, + surname=extracted_slot_values["person.surname"], + email=extracted_slot_values["person.email"], + ), + msg_len=extracted_slot_values["msg_len"], + ), + ), + ( + "person", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=unset_slot, + surname=unset_slot, + email=unset_slot, + ), + msg_len=extracted_slot_values["msg_len"], + ), + ), + ( + "msg_len", + ExtractedGroupSlot( + person=ExtractedGroupSlot( + name=extracted_slot_values["person.name"], + surname=extracted_slot_values["person.surname"], + email=extracted_slot_values["person.email"], + ), + msg_len=unset_slot, + ), + ), + ], + ) + def test_unset_slot(self, extracted_slot_manager, slot_name, expected_slot_storage): + extracted_slot_manager.unset_slot(slot_name) + assert extracted_slot_manager.slot_storage == expected_slot_storage + + def test_unset_all(self, extracted_slot_manager): + extracted_slot_manager.unset_all_slots() + assert extracted_slot_manager.slot_storage == unset_slot_storage + + @pytest.mark.parametrize("slot_name", ["person.name", "person", "msg_len"]) + def test_get_extracted_slot(self, extracted_slot_manager, slot_name): + assert extracted_slot_manager.get_extracted_slot(slot_name) == extracted_slot_values[slot_name] + + def test_get_extracted_slot_raises(self, extracted_slot_manager): + with pytest.raises(KeyError): + extracted_slot_manager.get_extracted_slot("none") + + def test_slot_extracted(self, fully_extracted_slot_manager, empty_slot_manager): + assert fully_extracted_slot_manager.is_slot_extracted("person.name") is True + assert fully_extracted_slot_manager.is_slot_extracted("person") is True + with pytest.raises(KeyError): + fully_extracted_slot_manager.is_slot_extracted("none") + assert fully_extracted_slot_manager.all_slots_extracted() is True + + assert empty_slot_manager.is_slot_extracted("person.name") is False + assert empty_slot_manager.is_slot_extracted("person") is False + with pytest.raises(KeyError): + empty_slot_manager.is_slot_extracted("none") + assert empty_slot_manager.all_slots_extracted() is False + + @pytest.mark.parametrize( + "template,filled_value", + [ + ( + "Your name is {person.name} {person.surname}, your email: {person.email}.", + "Your name is Bot None, your email: bot@bot.", + ), + ], + ) + def test_template_filling(self, extracted_slot_manager, template, filled_value): + assert extracted_slot_manager.fill_template(template) == filled_value + + def test_serializable(self): + serialized = full_slot_storage.model_dump_json() + assert full_slot_storage == ExtractedGroupSlot.model_validate_json(serialized) diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py new file mode 100644 index 000000000..396d7b168 --- /dev/null +++ b/tests/slots/test_slot_types.py @@ -0,0 +1,161 @@ +import pytest +from pydantic import ValidationError + +from dff.script import Message +from dff.slots.slots import ( + RegexpSlot, + GroupSlot, + FunctionSlot, + SlotNotExtracted, + ExtractedValueSlot, + ExtractedGroupSlot, +) + + +@pytest.mark.parametrize( + ("user_request", "regexp", "expected"), + [ + ( + Message(text="My name is Bot"), + "(?<=name is ).+", + ExtractedValueSlot.model_construct(extracted_value="Bot", is_slot_extracted=True, default_value=None), + ), + ( + Message(text="I won't tell you my name"), + "(?<=name is ).+$", + ExtractedValueSlot.model_construct( + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp="(?<=name is ).+$", request_text="I won't tell you my name" + ) + ), + is_slot_extracted=False, + default_value=None, + ), + ), + ], +) +async def test_regexp(user_request, regexp, expected, context, pipeline): + context.add_request(user_request) + slot = RegexpSlot(regexp=regexp) + result = await slot.get_value(context, pipeline) + assert result == expected + + +@pytest.mark.parametrize( + ("user_request", "func", "expected"), + [ + ( + Message(text="I am bot"), + lambda msg: msg.text.split(" ")[2], + ExtractedValueSlot.model_construct(extracted_value="bot", is_slot_extracted=True, default_value=None), + ), + ( + Message(text="My email is bot@bot"), + lambda msg: [i for i in msg.text.split(" ") if "@" in i][0], + ExtractedValueSlot.model_construct(extracted_value="bot@bot", is_slot_extracted=True, default_value=None), + ), + ], +) +async def test_function(user_request, func, expected, context, pipeline): + context.add_request(user_request) + slot = FunctionSlot(func=func) + result = await slot.get_value(context, pipeline) + assert result == expected + + async def async_func(*args, **kwargs): + return func(*args, **kwargs) + + slot = FunctionSlot(func=async_func) + result = await slot.get_value(context, pipeline) + assert result == expected + + +async def test_function_exception(context, pipeline): + def func(msg: Message): + raise RuntimeError("error") + + slot = FunctionSlot(func=func) + result = await slot.get_value(context, pipeline) + assert result.is_slot_extracted is False + assert isinstance(result.extracted_value, RuntimeError) + + +@pytest.mark.parametrize( + ("user_request", "slot", "expected", "is_extracted"), + [ + ( + Message(text="I am Bot. My email is bot@bot"), + GroupSlot( + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + ExtractedGroupSlot( + name=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + email=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="bot@bot", default_value=None + ), + ), + True, + ), + ( + Message(text="I am Bot. I won't tell you my email"), + GroupSlot( + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + ExtractedGroupSlot( + name=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + email=ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+", request_text="I am Bot. I won't tell you my email" + ) + ), + default_value=None, + ), + ), + False, + ), + ], +) +async def test_group_slot_extraction(user_request, slot, expected, is_extracted, context, pipeline): + context.add_request(user_request) + result = await slot.get_value(context, pipeline) + assert result == expected + assert result.__slot_extracted__ == is_extracted + + +@pytest.mark.parametrize("forbidden_name", ["__dunder__", "contains.dot"]) +def test_group_subslot_name_validation(forbidden_name): + with pytest.raises(ValidationError): + GroupSlot(**{forbidden_name: RegexpSlot(regexp="")}) + + +async def test_str_representation(): + assert ( + str(ExtractedValueSlot.model_construct(is_slot_extracted=True, extracted_value="hello", default_value=None)) + == "hello" + ) + assert ( + str(ExtractedValueSlot.model_construct(is_slot_extracted=False, extracted_value=None, default_value="hello")) + == "hello" + ) + assert ( + str( + ExtractedGroupSlot( + first_name=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Tom", default_value="John" + ), + last_name=ExtractedValueSlot.model_construct( + is_slot_extracted=False, extracted_value=None, default_value="Smith" + ), + ) + ) + == "{'first_name': 'Tom', 'last_name': 'Smith'}" + ) diff --git a/tests/slots/test_tutorials.py b/tests/slots/test_tutorials.py new file mode 100644 index 000000000..f85649eb8 --- /dev/null +++ b/tests/slots/test_tutorials.py @@ -0,0 +1,20 @@ +import importlib +import pytest +from tests.test_utils import get_path_from_tests_to_current_dir +from dff.utils.testing.common import check_happy_path + + +dot_path_to_addon = get_path_from_tests_to_current_dir(__file__, separator=".") + + +@pytest.mark.parametrize( + "tutorial_module_name", + [ + "1_basic_example", + ], +) +def test_examples(tutorial_module_name): + module = importlib.import_module(f"tutorials.{dot_path_to_addon}.{tutorial_module_name}") + pipeline = getattr(module, "pipeline") + happy_path = getattr(module, "HAPPY_PATH") + check_happy_path(pipeline, happy_path) diff --git a/tests/utils/test_serialization.py b/tests/utils/test_serialization.py index 334b8f14a..7e1dc0518 100644 --- a/tests/utils/test_serialization.py +++ b/tests/utils/test_serialization.py @@ -2,6 +2,7 @@ import pytest from pydantic import BaseModel +from copy import deepcopy import dff.utils.devel.json_serialization as json_ser @@ -63,7 +64,11 @@ def test_pickle(self, unserializable_obj): assert json_ser.pickle_validator(serialized) == unserializable_obj def test_json_pickle(self, unserializable_dict, non_serializable_fields, deserialized_dict): - serialized = json_ser.json_pickle_serializer(unserializable_dict) + dict_copy = deepcopy(unserializable_dict) + + serialized = json_ser.json_pickle_serializer(dict_copy) + + assert dict_copy == unserializable_dict, "Dict changed by serializer" assert serialized[json_ser._JSON_EXTRA_FIELDS_KEYS] == non_serializable_fields assert all(isinstance(serialized[field], str) for field in non_serializable_fields) @@ -80,7 +85,11 @@ class Class(BaseModel): obj = Class() obj.field = unserializable_obj - dump = obj.model_dump(mode="json") + obj_copy = obj.model_copy(deep=True) + + dump = obj_copy.model_dump(mode="json") + + assert obj == obj_copy, "Object changed by serializer" assert isinstance(dump["field"], str) @@ -94,7 +103,12 @@ class Class(BaseModel): obj = Class(field=unserializable_dict) - dump = obj.model_dump(mode="json") + obj_copy = obj.model_copy(deep=True) + + dump = obj_copy.model_dump(mode="json") + + assert obj == obj_copy, "Object changed by serializer" + assert dump["field"][json_ser._JSON_EXTRA_FIELDS_KEYS] == non_serializable_fields reconstructed_obj = Class.model_validate(dump) @@ -107,7 +121,12 @@ class Class(json_ser.JSONSerializableExtras): obj = Class(**unserializable_dict) - dump = obj.model_dump(mode="json") + obj_copy = obj.model_copy(deep=True) + + dump = obj_copy.model_dump(mode="json") + + assert obj == obj_copy, "Object changed by serializer" + assert dump[json_ser._JSON_EXTRA_FIELDS_KEYS] == non_serializable_fields reconstructed_obj = Class.model_validate(dump) diff --git a/tutorials/slots/1_basic_example.py b/tutorials/slots/1_basic_example.py new file mode 100644 index 000000000..bc07b32b5 --- /dev/null +++ b/tutorials/slots/1_basic_example.py @@ -0,0 +1,236 @@ +# %% [markdown] +""" +# 1. Basic Example + +The following tutorial shows basic usage of slots extraction +module packaged with `dff`. +""" + +# %pip install dff + +# %% +from dff.script import conditions as cnd +from dff.script import ( + RESPONSE, + TRANSITIONS, + PRE_TRANSITIONS_PROCESSING, + PRE_RESPONSE_PROCESSING, + GLOBAL, + LOCAL, + Message, +) + +from dff.pipeline import Pipeline +from dff.slots import GroupSlot, RegexpSlot +from dff.slots import processing as slot_procs +from dff.slots import response as slot_rsp +from dff.slots import conditions as slot_cnd + +from dff.utils.testing import ( + check_happy_path, + is_interactive_mode, + run_interactive_mode, +) + +# %% [markdown] +""" +The slots fall into the following category groups: + +- Value slots can be used to extract slot values from user utterances. +- Group slots can be used to split value slots into groups + with an arbitrary level of nesting. + +You can build the slot tree by passing the child slot instances as extra fields +of the parent slot. In the following cell, we define two slot groups: + + Group 1: person.username, person.email + Group 2: friend.first_name, friend.last_name + +Currently there are two types of value slots: + +- %mddoclink(api,slots.slots,RegexpSlot): + Extracts slot values via regexp. +- %mddoclink(api,slots.slots,FunctionSlot): + Extracts slot values with the help of a user-defined function. +""" + +# %% +SLOTS = GroupSlot( + person=GroupSlot( + username=RegexpSlot( + regexp=r"username is ([a-zA-Z]+)", + match_group_idx=1, + ), + email=RegexpSlot( + regexp=r"email is ([a-z@\.A-Z]+)", + match_group_idx=1, + ), + ), + friend=GroupSlot( + first_name=RegexpSlot(regexp=r"^[A-Z][a-z]+?(?= )"), + last_name=RegexpSlot(regexp=r"(?<= )[A-Z][a-z]+"), + ), +) + +# %% [markdown] +""" +The slots module provides several functions for managing slots in-script: + +- %mddoclink(api,slots.conditions,slots_extracted): + Condition for checking if specified slots are extracted. +- %mddoclink(api,slots.processing,extract): + A processing function that extracts specified slots. +- %mddoclink(api,slots.processing,extract_all): + A processing function that extracts all slots. +- %mddoclink(api,slots.processing,unset): + A processing function that marks specified slots as not extracted, + effectively resetting their state. +- %mddoclink(api,slots.processing,unset_all): + A processing function that marks all slots as not extracted. +- %mddoclink(api,slots.processing,fill_template): + A processing function that fills the `response` + Message text with extracted slot values. +- %mddoclink(api,slots.response,filled_template): + A response function that takes a Message with a + format-string text and returns Message + with its text string filled with extracted slot values. + +The usage of all the above functions is shown in the following script: +""" + +# %% +script = { + GLOBAL: {TRANSITIONS: {("username_flow", "ask"): cnd.regexp(r"^[sS]tart")}}, + "username_flow": { + LOCAL: { + PRE_TRANSITIONS_PROCESSING: { + "get_slot": slot_procs.extract("person.username") + }, + TRANSITIONS: { + ("email_flow", "ask", 1.2): slot_cnd.slots_extracted( + "person.username" + ), + ("username_flow", "repeat_question", 0.8): cnd.true(), + }, + }, + "ask": { + RESPONSE: Message(text="Write your username (my username is ...):"), + }, + "repeat_question": { + RESPONSE: Message( + text="Please, type your username again (my username is ...):" + ) + }, + }, + "email_flow": { + LOCAL: { + PRE_TRANSITIONS_PROCESSING: { + "get_slot": slot_procs.extract("person.email") + }, + TRANSITIONS: { + ("friend_flow", "ask", 1.2): slot_cnd.slots_extracted( + "person.username", "person.email" + ), + ("email_flow", "repeat_question", 0.8): cnd.true(), + }, + }, + "ask": { + RESPONSE: Message(text="Write your email (my email is ...):"), + }, + "repeat_question": { + RESPONSE: Message( + text="Please, write your email again (my email is ...):" + ) + }, + }, + "friend_flow": { + LOCAL: { + PRE_TRANSITIONS_PROCESSING: { + "get_slots": slot_procs.extract("friend") + }, + TRANSITIONS: { + ("root", "utter", 1.2): slot_cnd.slots_extracted( + "friend.first_name", "friend.last_name", mode="any" + ), + ("friend_flow", "repeat_question", 0.8): cnd.true(), + }, + }, + "ask": { + RESPONSE: Message( + text="Please, name me one of your friends: (John Doe)" + ) + }, + "repeat_question": { + RESPONSE: Message( + text="Please, name me one of your friends again: (John Doe)" + ) + }, + }, + "root": { + "start": { + RESPONSE: Message(text=""), + TRANSITIONS: {("username_flow", "ask"): cnd.true()}, + }, + "fallback": { + RESPONSE: Message(text="Finishing query"), + TRANSITIONS: {("username_flow", "ask"): cnd.true()}, + }, + "utter": { + RESPONSE: slot_rsp.filled_template( + Message( + text="Your friend is {friend.first_name} {friend.last_name}" + ) + ), + TRANSITIONS: {("root", "utter_alternative"): cnd.true()}, + }, + "utter_alternative": { + RESPONSE: Message( + text="Your username is {person.username}. " + "Your email is {person.email}." + ), + PRE_RESPONSE_PROCESSING: {"fill": slot_procs.fill_template()}, + TRANSITIONS: {("root", "fallback"): cnd.true()}, + }, + }, +} + +# %% +HAPPY_PATH = [ + ( + Message(text="hi"), + Message(text="Write your username (my username is ...):"), + ), + ( + Message(text="my username is groot"), + Message(text="Write your email (my email is ...):"), + ), + ( + Message(text="my email is groot@gmail.com"), + Message(text="Please, name me one of your friends: (John Doe)"), + ), + (Message(text="Bob Page"), Message(text="Your friend is Bob Page")), + ( + Message(text="ok"), + Message(text="Your username is groot. Your email is groot@gmail.com."), + ), + (Message(text="ok"), Message(text="Finishing query")), +] + +# %% +pipeline = Pipeline.from_script( + script, + start_label=("root", "start"), + fallback_label=("root", "fallback"), + slots=SLOTS, +) + +if __name__ == "__main__": + check_happy_path( + pipeline, HAPPY_PATH + ) # This is a function for automatic tutorial running + # (testing) with HAPPY_PATH + + # This runs tutorial in interactive mode if not in IPython env + # and if `DISABLE_INTERACTIVE_MODE` is not set + if is_interactive_mode(): + run_interactive_mode(pipeline)