# router.py
# Forked from starpig1129/AI-Data-Analysis-MultiAgent.
# Original listing: 124 lines (101 loc), 4.51 KB.
from state import State
from typing import Literal
from langchain_core.messages import AIMessage
import logging
import json
# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
logger = logging.getLogger(__name__)
# Type aliases describing the node names the routers in this file may return.
# NodeType is the return annotation of hypothesis_router and
# QualityReview_router.
NodeType = Literal['Visualization', 'Search', 'Coder', 'Report', 'Process', 'NoteTaker', 'Hypothesis', 'QualityReview']
# ProcessNodeType is the return annotation of process_router; it additionally
# allows 'Refiner', which process_router returns for a "FINISH" decision.
ProcessNodeType = Literal['Coder', 'Search', 'Visualization', 'Report', 'Process', 'Refiner']
def hypothesis_router(state: State) -> NodeType:
    """
    Route based on the presence of a hypothesis in the state.

    The ``hypothesis`` entry of the state may be an ``AIMessage``, a plain
    string, or absent. Its text is extracted and, if it is empty or only
    whitespace, the workflow is sent to the 'Hypothesis' node; otherwise it
    continues to 'Process'.

    Args:
        state (State): The current state of the system.

    Returns:
        NodeType: 'Hypothesis' if no hypothesis exists, otherwise 'Process'.
    """
    logger.info("Entering hypothesis_router")
    hypothesis = state.get("hypothesis")

    if isinstance(hypothesis, AIMessage):
        logger.debug("Hypothesis is an AIMessage")
        content = hypothesis.content
        if isinstance(content, str):
            hypothesis_content = content
        elif isinstance(content, list):
            # Message content is not always a string: it can be a list of
            # content blocks (plain strings or dicts carrying a "text" entry).
            # Calling .strip() on such a list would raise AttributeError, so
            # flatten the textual parts first.
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            hypothesis_content = "".join(parts)
        else:
            hypothesis_content = ""
    elif isinstance(hypothesis, str):
        hypothesis_content = hypothesis
        logger.debug("Hypothesis is a string")
    elif hypothesis is None:
        # A missing hypothesis is the normal situation before the Hypothesis
        # node has run; it is not an error and should not log a warning.
        hypothesis_content = ""
        logger.debug("No hypothesis present in state")
    else:
        hypothesis_content = ""
        logger.warning(f"Unexpected hypothesis type: {type(hypothesis)}")

    result: NodeType = "Process" if hypothesis_content.strip() else "Hypothesis"
    logger.info(f"hypothesis_router decision: {result}")
    return result
def QualityReview_router(state: State) -> NodeType:
    """
    Pick the node that follows a quality review.

    A revision is requested when the most recent message contains the text
    'REVISION' or when the state's ``needs_revision`` flag is truthy. In that
    case control goes back to the node named by ``last_sender`` if it is one
    of Visualization, Search, Coder or Report. In every other situation the
    next node is 'NoteTaker'.

    Args:
        state (State): The current state of the system.

    Returns:
        NodeType: The node to run next.
    """
    logger.info("Entering QualityReview_router")

    history = state.get("messages", [])
    latest = history[-1] if history else None

    # Evaluated first so the needs_revision flag is only consulted when the
    # message itself does not ask for a revision.
    asked_in_message = latest and 'REVISION' in str(latest.content)
    if not (asked_in_message or state.get("needs_revision", False)):
        return "NoteTaker"

    # Nodes whose output can be sent back for revision; each maps to itself.
    revisable = {
        node: node for node in ("Visualization", "Search", "Coder", "Report")
    }
    sender = state.get("last_sender", "")
    result = revisable.get(sender, "NoteTaker")
    logger.info(f"Revision needed. Routing to: {result}")
    return result
def _normalize_process_decision(raw: object) -> str:
    """Reduce a raw ``process_decision`` value to a stripped decision string.

    Accepts an ``AIMessage`` (whose content may be JSON such as
    ``{"next": "Coder"}``, possibly written with single quotes), a dict with
    a ``next`` key, a plain string, or any other value (converted with
    ``str``). ``None`` is treated as an empty decision.
    """
    if isinstance(raw, AIMessage):
        logger.debug("Process decision is an AIMessage")
        content = raw.content
        if not isinstance(content, str):
            # Message content may be a list of content blocks (strings or
            # dicts with a "text" entry) instead of a string; flatten it so
            # the string operations below cannot raise AttributeError.
            parts = []
            if isinstance(content, list):
                for block in content:
                    if isinstance(block, str):
                        parts.append(block)
                    elif isinstance(block, dict) and isinstance(block.get("text"), str):
                        parts.append(block["text"])
            content = "".join(parts)
        try:
            # Single quotes are swapped for double quotes so a Python-style
            # dict repr such as {'next': 'Coder'} still parses as JSON.
            parsed = json.loads(content.replace("'", '"'))
        except json.JSONDecodeError:
            # If JSON parsing fails, use content directly
            raw = content
            logger.warning("Failed to parse process decision as JSON. Using content directly.")
        else:
            # Valid JSON is not necessarily an object (e.g. '"Coder"' parses
            # to a str); only dicts have a 'next' entry to look up.
            raw = parsed.get('next', '') if isinstance(parsed, dict) else parsed
            logger.debug(f"Parsed process decision from JSON: {raw}")
    elif isinstance(raw, dict):
        raw = raw.get('next', '')
        logger.debug(f"Process decision is a dictionary. Using 'next' value: {raw}")

    if raw is None:
        return ""
    if not isinstance(raw, str):
        raw = str(raw)
        logger.warning(f"Unexpected process decision type. Converting to string: {raw}")
    # Surrounding whitespace/newlines would otherwise make a valid node name
    # fail the membership test in process_router.
    return raw.strip()


def process_router(state: State) -> ProcessNodeType:
    """
    Route based on the process decision in the state.

    The ``process_decision`` entry is normalized to a string and then mapped
    as follows: one of 'Coder', 'Search', 'Visualization' or 'Report' is
    returned unchanged; 'FINISH' routes to 'Refiner'; anything else
    (including an empty or missing decision) routes back to 'Process'.

    Args:
        state (State): The current state of the system.

    Returns:
        ProcessNodeType: The next process node to route to based on the process decision.
    """
    logger.info("Entering process_router")
    process_decision = _normalize_process_decision(state.get("process_decision", ""))

    # Define valid decisions
    valid_decisions = {"Coder", "Search", "Visualization", "Report"}

    if process_decision in valid_decisions:
        logger.info(f"Valid process decision: {process_decision}")
        return process_decision

    if process_decision == "FINISH":
        logger.info("Process decision is FINISH. Ending process.")
        return "Refiner"

    # Anything that reaches this point is empty or not a recognised decision.
    logger.warning(f"Invalid or empty process decision: {process_decision}. Defaulting to 'Process'.")
    return "Process"
# Module-level statement: logged when this module's top-level code runs
# (i.e. on import), after all routers above have been defined.
logger.info("Router module initialized")