-
Notifications
You must be signed in to change notification settings - Fork 4.7k
/
Copy pathinteractive.py
1741 lines (1389 loc) · 58.7 KB
/
interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import asyncio
import logging
import os
import textwrap
import uuid
from functools import partial
from multiprocessing import Process
from typing import (
Any,
Callable,
Deque,
Dict,
List,
Optional,
Text,
Tuple,
Union,
Set,
cast,
)
from sanic import Sanic, response
from sanic.exceptions import NotFound
from sanic.request import Request
from sanic.response import HTTPResponse
from terminaltables import AsciiTable, SingleTable
import terminaltables.width_and_alignment
import numpy as np
from aiohttp import ClientError
from colorclass import Color
import questionary
from questionary import Choice, Form, Question
from rasa import telemetry
import rasa.shared.utils.cli
import rasa.shared.utils.io
import rasa.cli.utils
import rasa.shared.data
from rasa.shared.nlu.constants import TEXT, INTENT_NAME_KEY
from rasa.shared.nlu.training_data.loading import RASA, RASA_YAML
from rasa.shared.core.constants import (
USER_INTENT_RESTART,
ACTION_LISTEN_NAME,
LOOP_NAME,
ACTIVE_LOOP,
LOOP_REJECTED,
REQUESTED_SLOT,
LOOP_INTERRUPTED,
ACTION_UNLIKELY_INTENT_NAME,
)
from rasa.core import run, utils
import rasa.core.train
from rasa.core.constants import DEFAULT_SERVER_FORMAT, DEFAULT_SERVER_PORT
from rasa.shared.core.domain import (
Domain,
KEY_INTENTS,
KEY_ENTITIES,
KEY_RESPONSES,
KEY_ACTIONS,
KEY_RESPONSES_TEXT,
)
import rasa.shared.core.events
from rasa.shared.core.events import (
ActionExecuted,
ActionReverted,
BotUttered,
Event,
Restarted,
UserUttered,
UserUtteranceReverted,
)
from rasa.shared.constants import (
INTENT_MESSAGE_PREFIX,
DEFAULT_SENDER_ID,
UTTER_PREFIX,
DOCS_URL_POLICIES,
)
from rasa.shared.core.trackers import EventVerbosity, DialogueStateTracker
from rasa.shared.core.training_data import visualization
from rasa.shared.core.training_data.visualization import (
VISUALIZATION_TEMPLATE_PATH,
visualize_neighborhood,
)
from rasa.core.utils import AvailableEndpoints
from rasa.shared.importers.rasa import TrainingDataImporter
from rasa.utils.common import update_sanic_log_level
from rasa.utils.endpoints import EndpointConfig
from rasa.shared.exceptions import InvalidConfigException
# noinspection PyProtectedMember
from rasa.shared.nlu.training_data import loading
from rasa.shared.nlu.training_data.message import Message
# WARNING: This command line UI is using an external library
# communicating with the shell - these functions are hard to test
# automatically. If you change anything in here, please make sure to
# run the interactive learning and check if your part of the "ui"
# still works.
import rasa.utils.io as io_utils
from rasa.shared.core.generator import TrackerWithCachedStates
logger = logging.getLogger(__name__)

# Default file locations; the "stories", "nlu" and "domain" entries are the
# defaults proposed in the export form (`_request_export_info`).
PATHS = {
    "stories": "data/stories.yml",
    "nlu": "data/nlu.yml",
    "backup": "data/nlu_interactive.yml",
    "domain": "domain.yml",
}

# Passed as `e2e=` when exporting stories in `_write_stories_to_file`.
SAVE_IN_E2E = False

# choose other intent, making sure this doesn't clash with an existing intent
# Random hex strings are used as menu "values" for the special entries
# ("create new intent", session-only actions, "create new action") so that they
# cannot collide with a real intent or action name.
OTHER_INTENT = uuid.uuid4().hex
OTHER_ACTION = uuid.uuid4().hex
NEW_ACTION = uuid.uuid4().hex
# Responses created during this session, keyed by action name
# (`{action: [{KEY_RESPONSES_TEXT: text}]}`); merged into the domain on export.
NEW_RESPONSES: Dict[Text, List[Dict[Text, Any]]] = {}

# NOTE(review): presumably an upper bound on how many training stories get
# visualized; its use is outside the visible part of this module.
MAX_NUMBER_OF_TRAINING_STORIES_FOR_VISUALIZATION = 200

# Default file name for the story graph (Graphviz `.dot`).
DEFAULT_STORY_GRAPH_FILE = "story_graph.dot"
class RestartConversation(Exception):
    """Raised to abandon the current flow and start the conversation over."""
class ForkTracker(Exception):
    """Raised to leave the current flow and fork the conversation.

    The tracker is reset to an earlier point chosen by the user, and the
    conversation then continues from that point.
    """
class UndoLastStep(Exception):
    """Raised to leave the current flow and undo the most recent step.

    A step is either the latest user message or the latest action the bot
    ran, whichever came last.
    """
class Abort(Exception):
    """Raised to stop interactive learning and exit."""
async def send_message(
    endpoint: EndpointConfig,
    conversation_id: Text,
    message: Text,
    parse_data: Optional[Dict[Text, Any]] = None,
) -> Optional[Any]:
    """Post a user message to a conversation on the server.

    Args:
        endpoint: Server endpoint to send the request to.
        conversation_id: Conversation the message belongs to.
        message: Text of the user message.
        parse_data: Optional parse result sent along with the text.

    Returns:
        Whatever `endpoint.request` returns for the POST request.
    """
    messages_path = f"/conversations/{conversation_id}/messages"
    body = dict(sender=UserUttered.type_name, text=message, parse_data=parse_data)
    return await endpoint.request(json=body, method="post", subpath=messages_path)
async def request_prediction(
    endpoint: EndpointConfig, conversation_id: Text
) -> Optional[Any]:
    """Ask the server to predict the next action for a conversation."""
    predict_path = f"/conversations/{conversation_id}/predict"
    return await endpoint.request(method="post", subpath=predict_path)
async def retrieve_domain(endpoint: EndpointConfig) -> Optional[Any]:
    """Fetch the server's domain, requesting it as JSON."""
    json_headers = {"Accept": "application/json"}
    return await endpoint.request(
        method="get", subpath="/domain", headers=json_headers
    )
async def retrieve_status(endpoint: EndpointConfig) -> Optional[Any]:
    """Fetch the server's status information."""
    status = await endpoint.request(method="get", subpath="/status")
    return status
async def retrieve_tracker(
    endpoint: EndpointConfig,
    conversation_id: Text,
    verbosity: EventVerbosity = EventVerbosity.ALL,
) -> Dict[Text, Any]:
    """Fetch the serialised tracker of a conversation.

    Args:
        endpoint: Server endpoint to send the request to.
        conversation_id: Conversation whose tracker is requested.
        verbosity: Which events to include; its name is sent as the
            `include_events` query parameter.

    Returns:
        The tracker as a dictionary.
    """
    tracker_path = (
        f"/conversations/{conversation_id}/tracker"
        f"?include_events={verbosity.name}"
    )
    response_body = await endpoint.request(
        method="get", subpath=tracker_path, headers={"Accept": "application/json"}
    )
    # An unsuccessful request has already raised inside `endpoint.request`, so
    # at this point the body is the tracker dictionary.
    return cast(Dict[Text, Any], response_body)
async def send_action(
    endpoint: EndpointConfig,
    conversation_id: Text,
    action_name: Text,
    policy: Optional[Text] = None,
    confidence: Optional[float] = None,
    is_new_action: bool = False,
) -> Optional[Any]:
    """Ask the server to execute an action in a conversation.

    If execution fails with `ClientError` for an action created during this
    session, the user is shown a warning. The action is then only logged as an
    `ActionExecuted` event, without being run. For any other action, the
    error is logged and re-raised.

    Args:
        endpoint: Server endpoint to send the request to.
        conversation_id: Conversation to run the action in.
        action_name: Name of the action to execute.
        policy: Policy that predicted the action, if any.
        confidence: Prediction confidence, if any.
        is_new_action: Whether the action was created in this session.

    Returns:
        The server response of the execute request or, on the fallback path,
        of the event-logging request.
    """
    execute_path = f"/conversations/{conversation_id}/execute"
    try:
        return await endpoint.request(
            json=ActionExecuted(action_name, policy, confidence).as_dict(),
            method="post",
            subpath=execute_path,
        )
    except ClientError:
        if not is_new_action:
            logger.error("failed to execute action!")
            raise

        if action_name in NEW_RESPONSES:
            warning_text = (
                f"WARNING: You have created a new action: '{action_name}', "
                f"with matching response: "
                f"'{NEW_RESPONSES[action_name][0][KEY_RESPONSES_TEXT]}'. "
                f"This action will not return its message in this session, "
                f"but the new response will be saved to your domain file "
                f"when you exit and save this session. "
                f"You do not need to do anything further."
            )
        else:
            warning_text = (
                f"WARNING: You have created a new action: '{action_name}', "
                f"which was not successfully executed. "
                f"If this action does not return any events, "
                f"you do not need to do anything. "
                f"If this is a custom action which returns events, "
                f"you are recommended to implement this action "
                f"in your action server and try again."
            )
        await _ask_questions(
            questionary.confirm(warning_text), conversation_id, endpoint
        )

        # Record the action in the tracker without running it.
        fallback_event = ActionExecuted(action_name).as_dict()
        return await send_event(endpoint, conversation_id, fallback_event)
async def send_event(
    endpoint: EndpointConfig,
    conversation_id: Text,
    evt: Union[List[Dict[Text, Any]], Dict[Text, Any]],
) -> Optional[Any]:
    """Append one serialised event, or a list of them, to a conversation's tracker."""
    events_path = f"/conversations/{conversation_id}/tracker/events"
    return await endpoint.request(subpath=events_path, method="post", json=evt)
def format_bot_output(message: BotUttered) -> Text:
    """Render a bot response as text for the chat history table.

    The message text comes first. Each of the optional `data` entries
    image, attachment, buttons, elements and quick replies that is present and
    not `None` then follows on its own line(s).
    """
    lines = [message.text or ""]
    payload = message.data or {}

    if payload.get("image") is not None:
        lines.append("Image: " + payload["image"])
    if payload.get("attachment") is not None:
        lines.append("Attachment: " + payload["attachment"])
    if payload.get("buttons") is not None:
        lines.append("Buttons:")
        lines.extend(
            rasa.cli.utils.button_choices_from_message_data(
                payload, allow_free_text_input=True
            )
        )

    # Elements and quick replies share the same "header + one line per item"
    # layout.
    for key, header in (("elements", "Elements:"), ("quick_replies", "Quick replies:")):
        items = payload.get(key)
        if items is not None:
            lines.append(header)
            lines.extend(
                rasa.cli.utils.element_to_string(item, position)
                for position, item in enumerate(items)
            )

    return "\n".join(lines)
def latest_user_message(events: List[Dict[Text, Any]]) -> Optional[Dict[Text, Any]]:
    """Return the most recent user message event.

    Args:
        events: Serialised tracker events in chronological order.

    Returns:
        The last event whose `"event"` field is the user-utterance type, or
        `None` if there is no such event.
    """
    return next(
        (
            event
            for event in reversed(events)
            if event.get("event") == UserUttered.type_name
        ),
        None,
    )
async def _ask_questions(
    questions: Union[Form, Question],
    conversation_id: Text,
    endpoint: EndpointConfig,
    is_abort: Callable[[Dict[Text, Any]], bool] = lambda x: False,
) -> Any:
    """Ask `questions` and return the answer, offering the exit menu on cancel.

    The exit menu is shown when the prompt returns `None` (e.g. Ctrl-C was
    pressed) or when `is_abort` flags the answer. The question is asked again
    only if that menu says so. The menu may also raise to change the flow of
    the session.
    """
    answers: Any = {}
    retry = True
    while retry:
        answers = await questions.ask_async()
        cancelled = answers is None or is_abort(answers)
        retry = await _ask_if_quit(conversation_id, endpoint) if cancelled else False
    return answers
def _selection_choices_from_intent_prediction(
    predictions: List[Dict[Text, Any]]
) -> List[Dict[Text, Any]]:
    """Turn intent predictions into menu choices, best prediction first.

    Predictions are ordered by descending confidence, with ties broken by
    intent name. Each label shows the confidence followed by the intent name.
    The value of each choice is the intent name.
    """
    ranked = sorted(
        predictions, key=lambda pred: (-pred["confidence"], pred[INTENT_NAME_KEY])
    )
    return [
        {
            INTENT_NAME_KEY: f'{pred.get("confidence"):03.2f} {pred.get(INTENT_NAME_KEY):40}',
            "value": pred.get(INTENT_NAME_KEY),
        }
        for pred in ranked
    ]
async def _request_free_text_intent(
    conversation_id: Text, endpoint: EndpointConfig
) -> Text:
    """Prompt the user for the name of a new intent (must not be empty)."""
    return await _ask_questions(
        questionary.text(
            message="Please type the intent name:",
            validate=io_utils.not_empty_validator("Please enter an intent name"),
        ),
        conversation_id,
        endpoint,
    )
async def _request_free_text_action(
    conversation_id: Text, endpoint: EndpointConfig
) -> Text:
    """Prompt the user for the name of a new action (must not be empty)."""
    return await _ask_questions(
        questionary.text(
            message="Please type the action name:",
            validate=io_utils.not_empty_validator("Please enter an action name"),
        ),
        conversation_id,
        endpoint,
    )
async def _request_free_text_utterance(
    conversation_id: Text, endpoint: EndpointConfig, action: Text
) -> Text:
    """Prompt the user for the text of the new bot response `action`."""
    prompt = f"Please type the message for your new bot response '{action}':"
    return await _ask_questions(
        questionary.text(
            message=prompt,
            validate=io_utils.not_empty_validator("Please enter a response"),
        ),
        conversation_id,
        endpoint,
    )
async def _request_selection_from_intents(
    intents: List[Dict[Text, Text]], conversation_id: Text, endpoint: EndpointConfig
) -> Text:
    """Let the user pick the correct intent from a list of menu choices."""
    intent_menu = questionary.select("What intent is it?", choices=intents)
    return await _ask_questions(intent_menu, conversation_id, endpoint)
async def _request_fork_point_from_list(
    forks: List[Dict[Text, Text]], conversation_id: Text, endpoint: EndpointConfig
) -> Text:
    """Let the user pick the user message before which to fork."""
    fork_menu = questionary.select(
        "Before which user message do you want to fork?", choices=forks
    )
    return await _ask_questions(fork_menu, conversation_id, endpoint)
async def _request_fork_from_user(
    conversation_id: Text, endpoint: EndpointConfig
) -> Optional[List[Dict[Text, Any]]]:
    """Ask the user at which point to fork the conversation.

    Forking means the conversation is reset and continued from an earlier
    point. The user picks one of their previous messages, most recent first.

    Returns:
        The events preceding the chosen user message (the events to keep).
        Returns `None` if nothing was selected, or if the conversation has no
        user message to fork at.
    """
    tracker = await retrieve_tracker(
        endpoint, conversation_id, EventVerbosity.AFTER_RESTART
    )
    events = tracker.get("events", [])

    choices = [
        {"name": event.get("text"), "value": index}
        for index, event in enumerate(events)
        if event.get("event") == UserUttered.type_name
    ]
    if not choices:
        # There is no fork point, and questionary's `select` refuses an empty
        # list of choices. Report "nothing selected" instead of crashing.
        return None

    fork_idx = await _request_fork_point_from_list(
        list(reversed(choices)), conversation_id, endpoint
    )
    if fork_idx is None:
        return None
    return events[: int(fork_idx)]
async def _request_intent_from_user(
    latest_message: Dict[Text, Any],
    intents: List[Text],
    conversation_id: Text,
    endpoint: EndpointConfig,
) -> Dict[Text, Any]:
    """Ask the user which intent `latest_message` should have had.

    Every name in `intents` that is missing from the message's intent ranking
    is appended to that ranking list with confidence 0.0. This modifies the
    message's parse data when it contains a ranking. A "<create_new_intent>"
    entry lets the user type a new intent name instead.

    Returns:
        The selected ranking entry, with its original confidence. For a
        typed-in intent, a new entry with confidence 1.0 is returned instead.
        If the selection matches no entry, `{INTENT_NAME_KEY: None}` is
        returned.
    """
    ranking = latest_message.get("parse_data", {}).get("intent_ranking", [])
    already_ranked = {entry[INTENT_NAME_KEY] for entry in ranking}
    ranking.extend(
        {INTENT_NAME_KEY: name, "confidence": 0.0}
        for name in intents
        if name not in already_ranked
    )

    # The free-text option goes first, followed by the ranked intents.
    menu = [{INTENT_NAME_KEY: "<create_new_intent>", "value": OTHER_INTENT}]
    menu += _selection_choices_from_intent_prediction(ranking)

    chosen = await _request_selection_from_intents(menu, conversation_id, endpoint)

    if chosen != OTHER_INTENT:
        # Hand back the entry with the confidence the model originally gave it.
        return next(
            (entry for entry in ranking if entry[INTENT_NAME_KEY] == chosen),
            {INTENT_NAME_KEY: None},
        )

    typed_name = await _request_free_text_intent(conversation_id, endpoint)
    return {INTENT_NAME_KEY: typed_name, "confidence": 1.0}
async def _print_history(conversation_id: Text, endpoint: EndpointConfig) -> None:
    """Print the chat history table and the current slot values.

    The table, slot info and closing separator are printed through the
    default executor. Each of those calls is awaited, so the output appears in
    order and is complete before the caller continues (e.g. shows the next
    prompt).
    """
    tracker_dump = await retrieve_tracker(
        endpoint, conversation_id, EventVerbosity.AFTER_RESTART
    )
    events = tracker_dump.get("events", [])

    table = _chat_history_table(events)
    slot_strings = _slot_history(tracker_dump)

    print("------")
    print("Chat History\n")
    loop = asyncio.get_running_loop()
    # Un-awaited executor jobs may run on different worker threads at the same
    # time. Their output could then be reordered or interleaved with the
    # prints above and below, so every job is awaited.
    await loop.run_in_executor(None, print, table)

    if slot_strings:
        print("\n")
        slots_info = f"Current slots: \n\t{', '.join(slot_strings)}\n"
        await loop.run_in_executor(None, print, slots_info)

    await loop.run_in_executor(None, print, "------")
def _chat_history_table(events: List[Dict[Text, Any]]) -> Text:
    """Create a table containing bot and user messages.

    Also includes additional information, like any events and
    prediction probabilities.

    Args:
        events: Serialised tracker events. They are rebuilt into a
            `DialogueStateTracker`, and only its `applied_events()` are shown.

    Returns:
        The rendered table (`SingleTable.table`) as a string.
    """

    def wrap(txt: Text, max_width: int) -> Text:
        # NOTE(review): `calc_true_wrapping_width` is neither defined nor
        # imported in the visible part of this module; presumably it is
        # defined further down in this file.
        true_wrapping_width = calc_true_wrapping_width(txt, max_width)
        return "\n".join(
            textwrap.wrap(txt, true_wrapping_width, replace_whitespace=False)
        )

    def colored(txt: Text, color: Text) -> Text:
        # Surround `txt` with `{color}...{/color}` markup tags; the result is
        # passed to `colorclass.Color` below.
        return "{" + color + "}" + txt + "{/" + color + "}"

    def format_user_msg(user_event: UserUttered, max_width: int) -> Text:
        # Two lines: the (wrapped, colored) message text and the intent name
        # with its confidence (1.0 if no confidence is recorded).
        intent = user_event.intent or {}
        intent_name = intent.get(INTENT_NAME_KEY, "")
        _confidence = intent.get("confidence", 1.0)
        # NOTE(review): `_as_md_message` is not visible in this part of the
        # file; presumably it renders the parse data as a training-example
        # string.
        _md = _as_md_message(user_event.parse_data)

        _lines = [
            colored(wrap(_md, max_width), "hired"),
            f"intent: {intent_name} {_confidence:03.2f}",
        ]
        return "\n".join(_lines)

    # Width helpers for the bot column (index 1) and the user column (index 3).
    # They query the table object, so the result depends on the rows appended
    # to `table_data` so far.
    def bot_width(_table: AsciiTable) -> int:
        return _table.column_max_width(1)

    def user_width(_table: AsciiTable) -> int:
        return _table.column_max_width(3)

    # Rows are `[row number, bot cell, spacer, user cell]`.
    def add_bot_cell(data: List[List[Union[Text, Color]]], cell: Text) -> None:
        data.append([len(data), Color(cell), "", ""])

    def add_user_cell(data: List[List[Union[Text, Color]]], cell: Text) -> None:
        data.append([len(data), "", "", Color(cell)])

    # prints the historical interactions between the bot and the user,
    # to help with correctly identifying the action
    table_data = [
        [
            "# ",
            Color(colored("Bot ", "autoblue")),
            " ",
            Color(colored("You ", "hired")),
        ]
    ]

    # The table keeps a reference to `table_data`; rows appended later are
    # therefore part of the table (and of its width calculations).
    table = SingleTable(table_data, "Chat History")

    # Bot-side lines collected since the last user message. They are flushed
    # into a single bot cell whenever a user message follows, and at the end.
    bot_column = []

    tracker = DialogueStateTracker.from_dict("any", events)
    applied_events = tracker.applied_events()

    for idx, event in enumerate(applied_events):
        if isinstance(event, ActionExecuted):
            # Hide `action_unlikely_intent` runs that have confidence 0.
            if (
                event.action_name == ACTION_UNLIKELY_INTENT_NAME
                and event.confidence == 0
            ):
                continue
            bot_column.append(colored(str(event), "autocyan"))
            if event.confidence is not None:
                bot_column[-1] += colored(f" {event.confidence:03.2f}", "autowhite")

        elif isinstance(event, UserUttered):
            if bot_column:
                text = "\n".join(bot_column)
                add_bot_cell(table_data, text)
                bot_column = []

            msg = format_user_msg(event, user_width(table))
            add_user_cell(table_data, msg)

        elif isinstance(event, BotUttered):
            wrapped = wrap(format_bot_output(event), bot_width(table))
            bot_column.append(colored(wrapped, "autoblue"))

        else:
            # Any other event is shown via its story string, if it has one.
            if event.as_story_string():
                bot_column.append(wrap(event.as_story_string(), bot_width(table)))

    # Flush bot lines that came after the last user message.
    if bot_column:
        text = "\n".join(bot_column)
        add_bot_cell(table_data, text)

    table.inner_heading_row_border = False
    table.inner_row_border = True
    table.inner_column_border = False
    table.outer_border = False
    table.justify_columns = {0: "left", 1: "left", 2: "center", 3: "right"}

    return table.table
def _slot_history(tracker_dump: Dict[Text, Any]) -> List[Text]:
    """Return one `"name: value"` string per slot, with the value highlighted."""
    return [
        "{}: {}".format(
            slot_name,
            rasa.shared.utils.io.wrap_with_color(
                str(slot_value), color=rasa.shared.utils.io.bcolors.WARNING
            ),
        )
        for slot_name, slot_value in tracker_dump.get("slots", {}).items()
    ]
async def _retry_on_error(
    func: Callable, export_path: Text, *args: Any, **kwargs: Any
) -> Any:
    """Call `func(export_path, *args, **kwargs)`, offering a retry on `OSError`.

    Each time the call raises `OSError`, the user is asked whether to try
    again.

    Returns:
        The value returned by `func`.

    Raises:
        OSError: If the call fails and the user declines to retry, or dismisses
            the prompt.
    """
    while True:
        try:
            return func(export_path, *args, **kwargs)
        except OSError as e:
            answer = await questionary.confirm(
                f"Failed to export '{export_path}': {e}. Please make sure 'rasa' "
                f"has read and write access to this file. Would you like to retry?"
            ).ask_async()
            if not answer:
                # A bare `raise` re-raises the original error with its traceback.
                raise
async def _write_data_to_file(conversation_id: Text, endpoint: EndpointConfig) -> None:
    """Ask for export paths, then write the stories, NLU data and domain there.

    Each write is retried on `OSError` if the user agrees.
    """
    story_path, nlu_path, domain_path = await _request_export_info()

    events = (await retrieve_tracker(endpoint, conversation_id)).get("events", [])
    domain = Domain.from_dict(await retrieve_domain(endpoint))

    exports = [
        (_write_stories_to_file, story_path, (events, domain)),
        (_write_nlu_to_file, nlu_path, (events,)),
        (_write_domain_to_file, domain_path, (events, domain)),
    ]
    for write_fn, target_path, extra_args in exports:
        await _retry_on_error(write_fn, target_path, *extra_args)

    logger.info("Successfully wrote stories and NLU data")
async def _ask_if_quit(conversation_id: Text, endpoint: EndpointConfig) -> bool:
    """Show the exit menu and act on the choice.

    Returns `True` ("ask the previous question again") when the user
    chooses to continue. Every other choice leaves via an exception:
    `UndoLastStep`, `ForkTracker` or `RestartConversation`. Quitting, or
    dismissing the menu, first exports the data and then raises `Abort`.
    """
    answer = await questionary.select(
        message="Do you want to stop?",
        choices=[
            Choice("Continue", "continue"),
            Choice("Undo Last", "undo"),
            Choice("Fork", "fork"),
            Choice("Start Fresh", "restart"),
            Choice("Export & Quit", "quit"),
        ],
    ).ask_async()

    if not answer or answer == "quit":
        # A dismissed menu (e.g. Ctrl-C) yields no answer and is treated as
        # "Export & Quit".
        await _write_data_to_file(conversation_id, endpoint)
        raise Abort()

    flow_changes = {
        "undo": UndoLastStep,
        "fork": ForkTracker,
        "restart": RestartConversation,
    }
    if answer in flow_changes:
        raise flow_changes[answer]()

    # "continue": the caller asks its original question again.
    return True
async def _request_action_from_user(
    predictions: List[Dict[Text, Any]], conversation_id: Text, endpoint: EndpointConfig
) -> Tuple[Text, bool]:
    """Ask the user to correct an action prediction.

    The menu lists, in order:
    - an entry for creating a brand-new action;
    - actions already executed in this session that are not among the
      predictions;
    - the predicted actions with their scores.

    A new action whose name starts with `UTTER_PREFIX` also gets a response
    text, which is stored in `NEW_RESPONSES`.

    Returns:
        `(action_name, is_new_action)`. `is_new_action` is `True` for an
        action typed in now or picked from the session-only entries.
    """
    await _print_history(conversation_id, endpoint)

    choices = [
        {"name": f'{a["score"]:03.2f} {a["action"]:40}', "value": a["action"]}
        for a in predictions
    ]

    tracker = await retrieve_tracker(endpoint, conversation_id)
    events = tracker.get("events", [])

    predicted_actions = {choice["value"] for choice in choices}
    # `dict.fromkeys` removes duplicates but keeps first-seen order, so the
    # menu is stable; `list(set(...))` shuffled it between runs.
    session_actions = dict.fromkeys(a["name"] for a in _collect_actions(events))
    new_actions = [
        {"name": action, "value": OTHER_ACTION + action}
        for action in session_actions
        if action not in predicted_actions
    ]

    choices = (
        [{"name": "<create new action>", "value": NEW_ACTION}] + new_actions + choices
    )
    question = questionary.select("What is the next action of the bot?", choices)

    action_name = await _ask_questions(question, conversation_id, endpoint)
    is_new_action = action_name == NEW_ACTION

    if is_new_action:
        # create new action
        action_name = await _request_free_text_action(conversation_id, endpoint)
        if action_name.startswith(UTTER_PREFIX):
            utter_message = await _request_free_text_utterance(
                conversation_id, endpoint, action_name
            )
            NEW_RESPONSES[action_name] = [{KEY_RESPONSES_TEXT: utter_message}]

    elif action_name.startswith(OTHER_ACTION):
        # The action was created earlier in this session, not in this turn.
        # Strip the sentinel prefix to get the real name.
        is_new_action = True
        action_name = action_name[len(OTHER_ACTION) :]

    print(f"Thanks! The bot will now run {action_name}.\n")
    return action_name, is_new_action
async def _request_export_info() -> Tuple[Text, Text, Text]:
    """Ask the user where to export the stories, NLU data and domain.

    The docstring now comes first so that it is the function's real
    `__doc__`. The former local `import rasa.shared.data` preceded it, and it
    duplicated the module-level import.

    Returns:
        The chosen `(stories_path, nlu_path, domain_path)`.

    Raises:
        Abort: If the export form is cancelled.
    """
    # export training data and quit
    questions = questionary.form(
        export_stories=questionary.text(
            message="Export stories to (if file exists, this "
            "will append the stories)",
            default=PATHS["stories"],
            validate=io_utils.file_type_validator(
                rasa.shared.data.YAML_FILE_EXTENSIONS,
                "Please provide a valid export path for the stories, "
                "e.g. 'stories.yml'.",
            ),
        ),
        export_nlu=questionary.text(
            message="Export NLU data to (if file exists, this will "
            "merge learned data with previous training examples)",
            default=PATHS["nlu"],
            validate=io_utils.file_type_validator(
                list(rasa.shared.data.TRAINING_DATA_EXTENSIONS),
                "Please provide a valid export path for the NLU data, "
                "e.g. 'nlu.yml'.",
            ),
        ),
        export_domain=questionary.text(
            message="Export domain file to (if file exists, this "
            "will be overwritten)",
            default=PATHS["domain"],
            validate=io_utils.file_type_validator(
                rasa.shared.data.YAML_FILE_EXTENSIONS,
                "Please provide a valid export path for the domain file, "
                "e.g. 'domain.yml'.",
            ),
        ),
    )

    answers = await questions.ask_async()
    if not answers:
        raise Abort()

    return answers["export_stories"], answers["export_nlu"], answers["export_domain"]
def _split_conversation_at_restarts(
    events: List[Dict[Text, Any]]
) -> List[List[Dict[Text, Any]]]:
    """Split serialised events into sub-conversations at `Restarted` events.

    The `Restarted` events themselves are dropped.

    Returns:
        One list of serialised events per sub-conversation.
    """
    parsed = [Event.from_parameters(raw_event) for raw_event in events]
    chunks = rasa.shared.core.events.split_events(
        parsed, Restarted, include_splitting_event=False
    )
    return [[chunk_event.as_dict() for chunk_event in chunk] for chunk in chunks]
def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
    """Build training `Message`s from the user utterance events.

    Untrainable entities are removed from each event's parse data with
    `remove_untrainable_entities_from`. That helper receives the event's own
    dict, so it presumably modifies it in place (verify). A
    `UserUtteranceReverted` event discards the most recently collected
    message.
    """
    import rasa.shared.nlu.training_data.util as rasa_nlu_training_data_utils

    collected = []
    for event in events:
        event_type = event.get("event")
        if event_type == UserUttered.type_name:
            parse_data = event.get("parse_data", {})
            rasa_nlu_training_data_utils.remove_untrainable_entities_from(parse_data)
            collected.append(
                Message.build(
                    parse_data["text"],
                    parse_data["intent"][INTENT_NAME_KEY],
                    parse_data["entities"],
                )
            )
        elif event_type == UserUtteranceReverted.type_name and collected:
            # The user corrected the NLU result; drop the incorrect example.
            collected.pop()
    return collected
def _collect_actions(events: List[Dict[Text, Any]]) -> List[Dict[Text, Any]]:
    """Return the serialised `ActionExecuted` events, in their original order."""
    action_type = ActionExecuted.type_name
    return list(filter(lambda evt: evt.get("event") == action_type, events))
def _write_stories_to_file(
    export_story_path: Text, events: List[Dict[Text, Any]], domain: Domain
) -> None:
    """Write the conversation as YAML stories to `export_story_path`.

    The events are split at restarts. Each sub-conversation that contains at
    least one user message becomes a story named `interactive_story_<n>`. If
    the file already exists, the stories are appended to it.

    Args:
        export_story_path: Target file; `io_utils.create_path` is called on it
            before writing.
        events: Serialised tracker events of the conversation.
        domain: Domain whose slots are used when rebuilding the trackers.
    """
    from rasa.shared.core.training_data.story_writer.yaml_story_writer import (
        YAMLStoryWriter,
    )

    # YAML is the only story format written here. `writer` must be bound for
    # every path; otherwise a non-YAML path crashes with `UnboundLocalError`
    # halfway through, after the file was already opened.
    if not rasa.shared.data.is_likely_yaml_file(export_story_path):
        logger.warning(
            f"'{export_story_path}' does not look like a YAML file; "
            f"the stories will be written in YAML format anyway."
        )
    writer = YAMLStoryWriter()

    sub_conversations = _split_conversation_at_restarts(events)
    io_utils.create_path(export_story_path)

    # Existing files are appended to, and the writer is told so.
    should_append_stories = os.path.exists(export_story_path)
    append_write = "a" if should_append_stories else "w"

    with open(
        export_story_path, append_write, encoding=rasa.shared.utils.io.DEFAULT_ENCODING
    ) as f:
        interactive_story_counter = 1
        for conversation in sub_conversations:
            parsed_events = rasa.shared.core.events.deserialise_events(conversation)
            tracker = DialogueStateTracker.from_events(
                f"interactive_story_{interactive_story_counter}",
                evts=parsed_events,
                slots=domain.slots,
            )

            # Sub-conversations without any user message are not exported.
            if any(
                isinstance(event, UserUttered) for event in tracker.applied_events()
            ):
                interactive_story_counter += 1
                f.write(
                    "\n"
                    + tracker.export_stories(
                        writer=writer,
                        should_append_stories=should_append_stories,
                        e2e=SAVE_IN_E2E,
                    )
                )
def _filter_messages(msgs: List[Message]) -> List[Message]:
    """Drop messages whose text starts with `INTENT_MESSAGE_PREFIX`.

    Those messages carry an explicit intent instead of natural text, so they
    are not kept as NLU examples.
    """
    return [
        message
        for message in msgs
        if not message.get(TEXT).startswith(INTENT_MESSAGE_PREFIX)
    ]
def _write_nlu_to_file(export_nlu_path: Text, events: List[Dict[Text, Any]]) -> None:
    """Merge the session's NLU examples with the data at `export_nlu_path`.

    Examples whose text starts with `INTENT_MESSAGE_PREFIX` are skipped. If
    loading the existing data fails for any reason, empty training data is
    used instead. NOTE(review): an existing but unreadable file would then be
    overwritten with only the new examples. The output format (YAML or JSON)
    is chosen from the target path.
    """
    from rasa.shared.nlu.training_data.training_data import TrainingData

    new_examples = _filter_messages(_collect_messages(events))

    # noinspection PyBroadException
    try:
        existing_data = loading.load_data(export_nlu_path)
    except Exception as e:
        logger.debug(f"An exception occurred while trying to load the NLU data. {e!s}")
        # Treated as "no previous file": start from empty training data.
        existing_data = TrainingData()

    merged = existing_data.merge(TrainingData(new_examples))

    # The format is guessed before the file is opened for writing, so the
    # guess never reads a file that is being written.
    if _get_nlu_target_format(export_nlu_path) == RASA_YAML:
        serialised = merged.nlu_as_yaml()
    else:
        serialised = merged.nlu_as_json()

    rasa.shared.utils.io.write_text_file(serialised, export_nlu_path)
def _get_nlu_target_format(export_path: Text) -> Text:
    """Choose the NLU output format for `export_path`.

    The format guessed by `loading.guess_format` is used when it is `RASA`
    or `RASA_YAML`. Otherwise the file extension decides: JSON gives `RASA`,
    YAML gives `RASA_YAML`. If both are inconclusive, the guessed format is
    returned unchanged.
    """
    guessed = loading.guess_format(export_path)
    if guessed in {RASA, RASA_YAML}:
        return guessed
    if rasa.shared.data.is_likely_json_file(export_path):
        return RASA
    if rasa.shared.data.is_likely_yaml_file(export_path):
        return RASA_YAML
    return guessed
def _entities_from_messages(messages: List[Message]) -> List[Text]:
    """Return the distinct entity names used in `messages` (in no set order)."""
    entity_names = set()
    for message in messages:
        entity_names.update(
            entity["entity"] for entity in message.data.get("entities", [])
        )
    return list(entity_names)
def _intents_from_messages(messages: List[Message]) -> Set[Text]:
    """Return the distinct intents of those `messages` that have one."""
    intents = set()
    for message in messages:
        if "intent" in message.data:
            intents.add(message.data["intent"])
    return intents
def _write_domain_to_file(
    domain_path: Text, events: List[Dict[Text, Any]], old_domain: Domain
) -> None:
    """Merge what was learned in the session into the domain and save it.

    The following are merged into `old_domain`, and the result is persisted
    to `domain_path`:
    - the intents and entities found in the user messages;
    - the responses created in this session (`NEW_RESPONSES`);
    - the executed actions, excluding default actions and forms already known
      to `old_domain`.
    """
    io_utils.create_path(domain_path)

    messages = _collect_messages(events)
    executed = _collect_actions(events)

    # TODO for now there is no way to distinguish between action and form
    action_names = {
        event["name"]
        for event in executed
        if event["name"] not in rasa.shared.core.constants.DEFAULT_ACTION_NAMES
        and event["name"] not in old_domain.form_names
    }

    learned = {
        KEY_INTENTS: list(_intents_from_messages(messages)),
        KEY_ENTITIES: _entities_from_messages(messages),
        KEY_RESPONSES: NEW_RESPONSES,
        KEY_ACTIONS: list(action_names),
    }
    old_domain.merge(Domain.from_dict(learned)).persist(domain_path)
async def _predict_till_next_listen(
endpoint: EndpointConfig,
conversation_id: Text,
conversation_ids: List[Text],
plot_file: Optional[Text],
) -> None:
"""Predict and validate actions until we need to wait for a user message."""
listen = False
while not listen:
result = await request_prediction(endpoint, conversation_id)
if result is None:
result = {}
predictions = result.get("scores", [])
if not predictions:
raise InvalidConfigException(
"Cannot continue as no action was predicted by the dialogue manager. "
"This can happen if you trained the assistant with no policy included "
"in the configuration. If so, please re-train the assistant with at "
f"least one policy ({DOCS_URL_POLICIES}) included in the configuration."
)
probabilities = [prediction["score"] for prediction in predictions]
pred_out = int(np.argmax(probabilities))
action_name = predictions[pred_out].get("action")
policy = result.get("policy")
confidence = result.get("confidence")
await _print_history(conversation_id, endpoint)
await _plot_trackers(
conversation_ids,
plot_file,
endpoint,
unconfirmed=[ActionExecuted(action_name)],
)
listen = await _validate_action(
action_name, policy, confidence, predictions, endpoint, conversation_id
)
await _plot_trackers(conversation_ids, plot_file, endpoint)
tracker_dump = await retrieve_tracker(
endpoint, conversation_id, EventVerbosity.AFTER_RESTART
)
events = tracker_dump.get("events", [])
if len(events) >= 2:
last_event = events[-2] # last event before action_listen
# if bot message includes buttons the user will get a list choice to reply
# the list choice is displayed in place of action listen
if last_event.get("event") == BotUttered.type_name and last_event["data"].get(
"buttons", None