Skip to content

Commit

Permalink
Merge slots (#36)
Browse files Browse the repository at this point in the history
Add slots feature

---------

Co-authored-by: Denis Kuznetsov <kuznetsov.den.p@gmail.com>
Co-authored-by: ruthenian8 <ruthenian8@gmail.com>
Co-authored-by: pseusys <shveitsar215@gmail.com>
  • Loading branch information
4 people authored Jul 2, 2024
1 parent 1593559 commit a2fe683
Show file tree
Hide file tree
Showing 28 changed files with 1,586 additions and 30 deletions.
12 changes: 12 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ services:
retries: 5
start_period: 30s

ner:
image: deeppavlov/deeppavlov:latest
profiles:
- extras
environment:
- CONFIG=ner_conll2003_bert
restart: unless-stopped
ports:
- 5000:5000
volumes:
- ~/.deeppavlov:/root/.deeppavlov/
- ~/.cache:/root/.cache/
dashboard:
env_file: [.env_file]
build:
Expand Down
2 changes: 1 addition & 1 deletion dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from dff.script.core.script import Script, Node
from dff.script.core.normalization import normalize_label, normalize_response
from dff.script.core.keywords import GLOBAL, LOCAL
from dff.pipeline.service.utils import wrap_sync_function_in_async
from dff.utils.devel.async_helpers import wrap_sync_function_in_async

logger = logging.getLogger(__name__)

Expand Down
10 changes: 10 additions & 0 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from dff.messengers.console import CLIMessengerInterface
from dff.messengers.common import MessengerInterface
from dff.slots.slots import GroupSlot
from ..service.group import ServiceGroup
from ..types import (
ServiceBuilder,
Expand Down Expand Up @@ -56,6 +57,7 @@ class Pipeline:
:param label_priority: Default priority value for all actor :py:const:`labels <dff.script.ConstLabel>`
where there is no priority. Defaults to `1.0`.
:param condition_handler: Handler that processes a call of actor condition functions. Defaults to `None`.
:param slots: Slots configuration.
:param handlers: This variable is responsible for the usage of external handlers on
the certain stages of work of :py:class:`~dff.script.Actor`.
Expand Down Expand Up @@ -89,6 +91,7 @@ def __init__(
fallback_label: Optional[NodeLabel2Type] = None,
label_priority: float = 1.0,
condition_handler: Optional[Callable] = None,
slots: Optional[Union[GroupSlot, Dict]] = None,
handlers: Optional[Dict[ActorStage, List[Callable]]] = None,
messenger_interface: Optional[MessengerInterface] = None,
context_storage: Optional[Union[DBContextStorage, Dict]] = None,
Expand All @@ -101,6 +104,7 @@ def __init__(
self.actor: Actor = None
self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface
self.context_storage = {} if context_storage is None else context_storage
self.slots = GroupSlot.model_validate(slots) if slots is not None else None
self._services_pipeline = ServiceGroup(
components,
before_handler=before_handler,
Expand Down Expand Up @@ -208,6 +212,7 @@ def from_script(
fallback_label: Optional[NodeLabel2Type] = None,
label_priority: float = 1.0,
condition_handler: Optional[Callable] = None,
slots: Optional[Union[GroupSlot, Dict]] = None,
parallelize_processing: bool = False,
handlers: Optional[Dict[ActorStage, List[Callable]]] = None,
context_storage: Optional[Union[DBContextStorage, Dict]] = None,
Expand All @@ -229,6 +234,7 @@ def from_script(
:param label_priority: Default priority value for all actor :py:const:`labels <dff.script.ConstLabel>`
where there is no priority. Defaults to `1.0`.
:param condition_handler: Handler that processes a call of actor condition functions. Defaults to `None`.
:param slots: Slots configuration.
:param parallelize_processing: This flag determines whether or not the functions
defined in the ``PRE_RESPONSE_PROCESSING`` and ``PRE_TRANSITIONS_PROCESSING`` sections
of the script should be parallelized over respective groups.
Expand Down Expand Up @@ -257,6 +263,7 @@ def from_script(
fallback_label=fallback_label,
label_priority=label_priority,
condition_handler=condition_handler,
slots=slots,
parallelize_processing=parallelize_processing,
handlers=handlers,
messenger_interface=messenger_interface,
Expand Down Expand Up @@ -320,6 +327,9 @@ async def _run_pipeline(
if update_ctx_misc is not None:
ctx.misc.update(update_ctx_misc)

if self.slots is not None:
ctx.framework_data.slot_manager.set_root_slot(self.slots)

ctx.add_request(request)
result = await self._services_pipeline(ctx, self)

Expand Down
3 changes: 2 additions & 1 deletion dff/pipeline/service/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from dff.script import Context

from .utils import collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates, wrap_sync_function_in_async
from .utils import collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates
from dff.utils.devel.async_helpers import wrap_sync_function_in_async
from ..types import (
ServiceRuntimeInfo,
ExtraHandlerType,
Expand Down
3 changes: 2 additions & 1 deletion dff/pipeline/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from dff.script import Context

from .utils import wrap_sync_function_in_async, collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates
from .utils import collect_defined_constructor_parameters_to_dict, _get_attrs_with_updates
from dff.utils.devel.async_helpers import wrap_sync_function_in_async
from ..types import (
ServiceBuilder,
StartConditionCheckerFunction,
Expand Down
19 changes: 1 addition & 18 deletions dff/pipeline/service/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,7 @@
These functions provide a variety of utility functionality.
"""

import asyncio
from typing import Callable, Any, Optional, Tuple, Mapping


async def wrap_sync_function_in_async(func: Callable, *args, **kwargs) -> Any:
"""
Utility function, that wraps both functions and coroutines in coroutines.
Invokes `func` if it is just a callable and awaits, if this is a coroutine.
:param func: Callable to wrap.
:param \\*args: Function args.
:param \\**kwargs: Function kwargs.
:return: What function returns.
"""
if asyncio.iscoroutinefunction(func):
return await func(*args, **kwargs)
else:
return func(*args, **kwargs)
from typing import Any, Optional, Tuple, Mapping


def _get_attrs_with_updates(
Expand Down
3 changes: 3 additions & 0 deletions dff/script/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dff.script.core.message import Message
from dff.script.core.types import NodeLabel2Type
from dff.pipeline.types import ComponentExecutionState
from dff.slots.slots import SlotManager

if TYPE_CHECKING:
from dff.script.core.script import Node
Expand Down Expand Up @@ -56,6 +57,8 @@ class FrameworkData(BaseModel):
"Actor service data. Cleared at the end of every turn."
stats: Dict[str, Any] = Field(default_factory=dict)
"Enables complex stats collection across multiple turns."
slot_manager: SlotManager = Field(default_factory=SlotManager)
"Stores extracted slots."


class Context(BaseModel):
Expand Down
4 changes: 3 additions & 1 deletion dff/script/core/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ def normalize_label(label: Label, default_flow_label: LabelType = "") -> Label:
"""
if callable(label):

def get_label_handler(ctx: Context, pipeline: Pipeline) -> ConstLabel:
def get_label_handler(ctx: Context, pipeline: Pipeline) -> Optional[ConstLabel]:
try:
new_label = label(ctx, pipeline)
if new_label is None:
return None
new_label = normalize_label(new_label, default_flow_label)
flow_label, node_label, _ = new_label
node = pipeline.script.get(flow_label, {}).get(node_label)
Expand Down
7 changes: 7 additions & 0 deletions dff/slots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# flake8: noqa: F401

from dff.slots.slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot
from dff.slots.conditions import slots_extracted
from dff.slots.processing import extract, extract_all, unset, unset_all, fill_template
from dff.slots.response import filled_template
32 changes: 32 additions & 0 deletions dff/slots/conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Conditions
---------------------------
Provides slot-related conditions.
"""

from __future__ import annotations
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
from dff.script import Context
from dff.slots.slots import SlotName
from dff.pipeline import Pipeline


def slots_extracted(*slots: SlotName, mode: Literal["any", "all"] = "all"):
"""
Conditions that checks if slots are extracted.
:param slots: Names for slots that need to be checked.
:param mode: Whether to check if all slots are extracted or any slot is extracted.
"""

def check_slot_state(ctx: Context, pipeline: Pipeline) -> bool:
manager = ctx.framework_data.slot_manager
if mode == "all":
return all(manager.is_slot_extracted(slot) for slot in slots)
elif mode == "any":
return any(manager.is_slot_extracted(slot) for slot in slots)
raise ValueError(f"{mode!r} not in ['any', 'all'].")

return check_slot_state
98 changes: 98 additions & 0 deletions dff/slots/processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Processing
---------------------------
This module provides wrappers for :py:class:`~dff.slots.slots.SlotManager`'s API.
"""

from __future__ import annotations

import logging
from typing import Awaitable, Callable, TYPE_CHECKING

if TYPE_CHECKING:
from dff.slots.slots import SlotName
from dff.script import Context
from dff.pipeline import Pipeline

logger = logging.getLogger(__name__)


def extract(*slots: SlotName) -> Callable[[Context, Pipeline], Awaitable[None]]:
"""
Extract slots listed slots.
This will override all slots even if they are already extracted.
:param slots: List of slot names to extract.
"""

async def inner(ctx: Context, pipeline: Pipeline) -> None:
manager = ctx.framework_data.slot_manager
for slot in slots: # todo: maybe gather
await manager.extract_slot(slot, ctx, pipeline)

return inner


def extract_all():
"""
Extract all slots defined in the pipeline.
"""

async def inner(ctx: Context, pipeline: Pipeline):
manager = ctx.framework_data.slot_manager
await manager.extract_all(ctx, pipeline)

return inner


def unset(*slots: SlotName) -> Callable[[Context, Pipeline], None]:
"""
Mark specified slots as not extracted and clear extracted values.
:param slots: List of slot names to extract.
"""

def unset_inner(ctx: Context, pipeline: Pipeline) -> None:
manager = ctx.framework_data.slot_manager
for slot in slots:
manager.unset_slot(slot)

return unset_inner


def unset_all():
"""
Mark all slots as not extracted and clear all extracted values.
"""

def inner(ctx: Context, pipeline: Pipeline):
manager = ctx.framework_data.slot_manager
manager.unset_all_slots()

return inner


def fill_template() -> Callable[[Context, Pipeline], None]:
"""
Fill the response template in the current node.
Response message of the current node should be a format-string: e.g. "Your username is {profile.username}".
"""

def inner(ctx: Context, pipeline: Pipeline) -> None:
manager = ctx.framework_data.slot_manager
# get current node response
response = ctx.current_node.response

if response is None:
return

if callable(response):
response = response(ctx, pipeline)

new_text = manager.fill_template(response.text)

response.text = new_text
ctx.current_node.response = response

return inner
34 changes: 34 additions & 0 deletions dff/slots/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Response
---------------------------
Slot-related DFF responses.
"""

from __future__ import annotations
from typing import Callable, TYPE_CHECKING

if TYPE_CHECKING:
from dff.script import Context, Message
from dff.pipeline import Pipeline


def filled_template(template: Message) -> Callable[[Context, Pipeline], Message]:
"""
Fill template with slot values.
The `text` attribute of the template message should be a format-string:
e.g. "Your username is {profile.username}".
For the example above, if ``profile.username`` slot has value "admin",
it would return a copy of the message with the following text:
"Your username is admin".
:param template: Template message with a format-string text.
"""

def fill_inner(ctx: Context, pipeline: Pipeline) -> Message:
message = template.model_copy()
new_text = ctx.framework_data.slot_manager.fill_template(template.text)
message.text = new_text
return message

return fill_inner
Loading

0 comments on commit a2fe683

Please sign in to comment.