-
Notifications
You must be signed in to change notification settings - Fork 75
/
Copy pathturtle_agent.py
executable file
·311 lines (261 loc) · 10 KB
/
turtle_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#!/usr/bin/env python3.9
# Copyright (c) 2024. Jet Propulsion Laboratory. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from datetime import datetime
import dotenv
import pyinputplus as pyip
import rospy
from langchain.agents import tool, Tool
# from langchain_ollama import ChatOllama
from rich.console import Console
from rich.console import Group
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.text import Text
from rosa import ROSA
import tools.turtle as turtle_tools
from help import get_help
from llm import get_llm
from prompts import get_prompts
# Typical method for defining tools in ROSA: decorate a plain function with
# LangChain's `@tool` and pass it to the agent via the `tools=[...]` argument.
# NOTE(review): the docstring presumably doubles as the tool description that
# is shown to the LLM, so treat its wording as runtime text -- confirm before
# rewording it.
@tool
def cool_turtle_tool():
    """A cool turtle tool that doesn't really do anything."""
    return "This is a cool turtle tool! It doesn't do anything, but it's cool."
class TurtleAgent(ROSA):
    """Interactive console agent for TurtleSim, built on top of `ROSA`.

    The agent reads free-form queries and a few built-in commands from the
    terminal, forwards queries to the underlying agent (`invoke` / `astream`,
    presumably provided by the `ROSA` base class), and renders the answers
    with `rich`.

    Attributes set by this class:
        examples: Example queries offered by the ``examples`` command.
        command_handler: Maps a command word to a zero-argument callable that
            returns an awaitable. ``info`` is added/removed dynamically
            depending on whether the last streamed query produced events.
        last_events: Tool/error events recorded during the last streamed query.
    """

    def __init__(self, streaming: bool = False, verbose: bool = True):
        """Configure the LLM, tools and prompts, then initialise the base agent.

        Args:
            streaming: When True, responses are streamed token by token via
                `stream_response`; otherwise they are printed in one piece.
            verbose: Forwarded unchanged to `ROSA.__init__`.
        """
        self.__blacklist = ["master", "docker"]
        self.__prompts = get_prompts()
        self.__llm = get_llm(streaming=streaming)

        # self.__llm = ChatOllama(
        #     base_url="host.docker.internal:11434",
        #     model="llama3.1",
        #     temperature=0,
        #     num_ctx=8192,
        # )

        self.__streaming = streaming

        # Another method for adding tools
        blast_off = Tool(
            name="blast_off",
            func=self.blast_off,
            description="Make the turtle blast off!",
        )

        super().__init__(
            ros_version=1,
            llm=self.__llm,
            tools=[cool_turtle_tool, blast_off],
            tool_packages=[turtle_tools],
            blacklist=self.__blacklist,
            prompts=self.__prompts,
            verbose=verbose,
            accumulate_chat_history=True,
            streaming=streaming,
        )

        self.examples = [
            "Give me a ROS tutorial using the turtlesim.",
            "Show me how to move the turtle forward.",
            "Draw a 5-point star using the turtle.",
            "Teleport to (3, 3) and draw a small hexagon.",
            "Give me a list of nodes, topics, services, params, and log files.",
            "Change the background color to light blue and the pen color to red.",
        ]

        # Created up front so `show_event_details` never hits a missing
        # attribute, even if it is called before the first clear/stream.
        self.last_events = []

        # Every handler returns an awaitable; `run` awaits the result.
        self.command_handler = {
            "help": lambda: self.submit(get_help(self.examples)),
            "examples": lambda: self.submit(self.choose_example()),
            "clear": self.clear,
        }

    def blast_off(self, input: str):
        """Return the canned "blast off" message handed back to the agent.

        Args:
            input: Unused. The name (which shadows the builtin) is kept so the
                signature seen by the `Tool` wrapper stays unchanged.

        Returns:
            A string containing a user-facing sentence followed by a
            ``<ROSA_INSTRUCTIONS>`` section addressed to the agent.
        """
        return (
            "\n"
            "Ok, we're blasting off at the speed of light!\n"
            "<ROSA_INSTRUCTIONS>\n"
            "You should now use your tools to make the turtle move around the screen at high speeds.\n"
            "</ROSA_INSTRUCTIONS>\n"
        )

    @property
    def greeting(self) -> Text:
        """Build the styled greeting, listing the currently available commands."""
        greeting = Text(
            "\nHi! I'm the ROSA-TurtleSim agent 🐢🤖. How can I help you today?\n"
        )
        greeting.stylize("frame bold blue")
        greeting.append(
            f"Try {', '.join(self.command_handler)} or exit.",
            style="italic",
        )
        return greeting

    def choose_example(self) -> str:
        """Get user selection from the list of examples.

        Returns:
            The example string picked by the user, or the first example if the
            prompt times out.
        """
        return pyip.inputMenu(
            self.examples,
            prompt="\nEnter your choice and press enter: \n",
            numbered=True,
            blank=False,
            timeout=60,
            # pyinputplus hands `default` back verbatim on timeout, so it must
            # be a real example; the menu index "1" would be submitted as-is.
            default=self.examples[0],
        )

    async def clear(self) -> None:
        """Clear the chat history, the recorded events and the terminal."""
        self.clear_chat()
        self.last_events = []
        # Without events there is nothing for the `info` command to show.
        self.command_handler.pop("info", None)
        os.system("clear")

    def get_input(self, prompt: str) -> str:
        """Get user input from the console."""
        return pyip.inputStr(prompt, default="help")

    async def run(self) -> None:
        """Run the TurtleAgent's main interaction loop.

        Clears the screen, then repeatedly prints the greeting and reads one
        line of input. ``exit`` ends the loop, a key of `command_handler`
        (e.g. ``help``, ``examples``, ``clear``, and ``info`` when available)
        runs that handler, and anything else is submitted as a query.

        Returns:
            None

        Raises:
            Exceptions raised by command handlers or by `submit` are not
            caught here and propagate to the caller.
        """
        await self.clear()
        console = Console()

        while True:
            console.print(self.greeting)
            user_input = self.get_input("> ")

            # Handle special commands
            if user_input == "exit":
                break
            if user_input in self.command_handler:
                await self.command_handler[user_input]()
            else:
                await self.submit(user_input)

    async def submit(self, query: str) -> None:
        """Send `query` to the agent, streaming or not per the constructor flag."""
        if self.__streaming:
            await self.stream_response(query)
        else:
            self.print_response(query)

    def print_response(self, query: str) -> None:
        """Submit the query to the agent and print the response to the console.

        Args:
            query (str): The input query to process.

        Returns:
            None
        """
        response = self.invoke(query)
        console = Console()
        content_panel = Panel(
            Markdown(response), title="Final Response", border_style="green"
        )

        with Live(
            console=console, auto_refresh=True, vertical_overflow="visible"
        ) as live:
            live.update(content_panel, refresh=True)

    async def stream_response(self, query: str) -> None:
        """Stream the agent's response with rich formatting.

        Tokens are appended to a live-updating panel as they arrive. Tool and
        error events are stored in `last_events`, and the ``info`` command is
        registered only when at least one such event was recorded.

        Args:
            query (str): The input query to process.

        Returns:
            None

        Raises:
            Exceptions raised while iterating `astream` are not caught here.
        """
        console = Console()
        content = ""
        self.last_events = []

        panel = Panel("", title="Streaming Response", border_style="green")
        with Live(panel, console=console, auto_refresh=False) as live:
            async for event in self.astream(query):
                # Millisecond-precision timestamp (microseconds truncated).
                event["timestamp"] = datetime.now().strftime(
                    "%Y-%m-%d %H:%M:%S.%f"
                )[:-3]
                event_type = event["type"]

                if event_type == "token":
                    content += event["content"]
                    panel.renderable = Markdown(content)
                    live.refresh()
                elif event_type in ("tool_start", "tool_end", "error"):
                    self.last_events.append(event)
                elif event_type == "final":
                    # The final event replaces whatever was accumulated.
                    content = event["content"]
                    if self.last_events:
                        panel.renderable = Markdown(
                            content
                            + "\n\nType 'info' for details on how I got my answer."
                        )
                    else:
                        panel.renderable = Markdown(content)
                    panel.title = "Final Response"
                    live.refresh()

        if self.last_events:
            self.command_handler["info"] = self.show_event_details
        else:
            self.command_handler.pop("info", None)

    async def show_event_details(self) -> None:
        """Display detailed information about the events of the last query."""
        console = Console()

        if not self.last_events:
            console.print("[yellow]No events to display.[/yellow]")
            return

        console.print(Markdown("# Tool Usage and Events"))

        for event in self.last_events:
            timestamp = event["timestamp"]
            if event["type"] == "tool_start":
                console.print(
                    Panel(
                        Group(
                            Text(f"Input: {event.get('input', 'None')}"),
                            Text(f"Timestamp: {timestamp}", style="dim"),
                        ),
                        title=f"Tool Started: {event['name']}",
                        border_style="blue",
                    )
                )
            elif event["type"] == "tool_end":
                console.print(
                    Panel(
                        Group(
                            Text(f"Output: {event.get('output', 'N/A')}"),
                            Text(f"Timestamp: {timestamp}", style="dim"),
                        ),
                        title=f"Tool Completed: {event['name']}",
                        border_style="green",
                    )
                )
            elif event["type"] == "error":
                console.print(
                    Panel(
                        Group(
                            Text(f"Error: {event['content']}", style="bold red"),
                            Text(f"Timestamp: {timestamp}", style="dim"),
                        ),
                        border_style="red",
                    )
                )
            # Blank line between events.
            console.print()

        console.print("[bold]End of events[/bold]\n")
def main():
    """Load environment settings, build the agent and run its console loop."""
    # Load variables from whichever .env file `find_dotenv` locates.
    env_path = dotenv.find_dotenv()
    dotenv.load_dotenv(env_path)

    # The `~streaming` ROS parameter selects streamed output; defaults to False.
    use_streaming = rospy.get_param("~streaming", False)

    agent = TurtleAgent(verbose=False, streaming=use_streaming)
    # `run` is a coroutine; drive it to completion on a fresh event loop.
    asyncio.run(agent.run())
if __name__ == "__main__":
    # Register this process as the "rosa" ROS node first: main() reads a ROS
    # parameter via rospy.get_param.
    rospy.init_node("rosa", log_level=rospy.INFO)
    main()