Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADAG] Add visualization of compiled graphs #47958

Merged
merged 16 commits into from
Oct 24, 2024
129 changes: 129 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,135 @@ async def execute_async(
self._execution_index += 1
return fut

def visualize(
self, filename="compiled_graph", format="png", view=False, return_dot=False
):
"""
Visualize the compiled graph using Graphviz.

This method generates a graphical representation of the compiled graph,
showing tasks and their dependencies.This method should be called
**after** the graph has been compiled using `experimental_compile()`.

Args:
filename: The name of the output file (without extension).
format: The format of the output file (e.g., 'png', 'pdf').
view: Whether to open the file with the default viewer.
return_dot: If True, returns the DOT source as a string instead of figure.

Raises:
ValueError: If the graph is empty or not properly compiled.
ImportError: If the `graphviz` package is not installed.

"""
import graphviz
from ray.dag import (
InputAttributeNode,
InputNode,
MultiOutputNode,
ClassMethodNode,
DAGNode,
)

# Check that the DAG has been compiled
if not hasattr(self, "idx_to_task") or not self.idx_to_task:
raise ValueError(
"The DAG must be compiled before calling 'visualize()'. "
"Please call 'experimental_compile()' first."
)

# Check that each CompiledTask has a valid dag_node
for idx, task in self.idx_to_task.items():
if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode):
raise ValueError(
f"Task at index {idx} does not have a valid 'dag_node'. "
"Ensure that 'experimental_compile()' completed successfully."
)

# Dot file for debuging
dot = graphviz.Digraph(name="compiled_graph", format=format)

# Add nodes with task information
for idx, task in self.idx_to_task.items():
dag_node = task.dag_node

# Initialize the label and attributes
label = f"Task {idx}\n"
shape = "oval" # Default shape
style = "filled"
fillcolor = ""

# Handle different types of dag_node
if isinstance(dag_node, InputNode):
label += "InputNode"
shape = "rectangle"
fillcolor = "lightblue"
elif isinstance(dag_node, InputAttributeNode):
label += f"InputAttributeNode[{dag_node.key}]"
shape = "rectangle"
fillcolor = "lightblue"
elif isinstance(dag_node, MultiOutputNode):
label += "MultiOutputNode"
shape = "rectangle"
fillcolor = "yellow"
elif isinstance(dag_node, ClassMethodNode):
if dag_node.is_class_method_call:
# Class Method Call Node
method_name = dag_node.get_method_name()
actor_handle = dag_node._get_actor_handle()
if actor_handle:
actor_id = actor_handle._actor_id.hex()
label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}"
else:
label += f"Method: {method_name}"
shape = "oval"
fillcolor = "lightgreen"
elif dag_node.is_class_method_output:
# Class Method Output Node
label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
shape = "rectangle"
fillcolor = "orange"
else:
# Unexpected ClassMethodNode
label += "ClassMethodNode"
shape = "diamond"
fillcolor = "red"
else:
# Unexpected node type
label += type(dag_node).__name__
shape = "diamond"
fillcolor = "red"

# Add the node to the graph with attributes
dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)

# Add edges with type hints based on argument mappings
for idx, task in self.idx_to_task.items():
current_task_idx = idx

for arg_index, arg in enumerate(task.dag_node.get_args()):
if isinstance(arg, DAGNode):
# Get the upstream task index
upstream_task_idx = self.dag_node_to_idx[arg]

# Get the type hint for this argument
if arg_index < len(task.arg_type_hints):
type_hint = type(task.arg_type_hints[arg_index]).__name__
else:
type_hint = "UnknownType"

# Draw an edge from the upstream task to the
# current task with the type hint
dot.edge(
str(upstream_task_idx), str(current_task_idx), label=type_hint
)

if return_dot:
return dot.source
else:
# Render the graph to a file
dot.render(filename, view=view)

def teardown(self):
"""Teardown and cancel all actor tasks for this DAG. After this
function returns, the actors should be available to execute new tasks
Expand Down
171 changes: 171 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are test failures for this file. Please fix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that graphviz is not available on the test environment so we need to pip install graphviz and sudo apt-get install graphviz for visualization test. Do you know how we can achieve that in the test environments? I need to add dependency to which file for pip and apt-get? @ruisearch42

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can introduce something like a test requirement.txt , something like python/requirements/ml/data-test-requirements.txt
If that is too complex, we can skip the test for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just skip these tests and I think it's OK now.

Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@

logger = logging.getLogger(__name__)

try:
import pydot
except Exception:
logging.info("pydot is not installed, visualization tests will be skiped")

pytestmark = [
pytest.mark.skipif(
Expand Down Expand Up @@ -2493,6 +2497,173 @@ async def main():
compiled_dag.teardown()


class TestVisualization:
Bye-legumes marked this conversation as resolved.
Show resolved Hide resolved

"""Tests for the visualize method of compiled DAGs."""

# TODO(zhilong): "pip intsall pydot"
# and "sudo apt-get install graphviz " to run test.
@pytest.fixture(autouse=True)
def skip_if_pydot_graphviz_not_available(self):
# Skip the test if pydot or graphviz is not available
pytest.importorskip("pydot")
pytest.importorskip("graphviz")

def test_visualize_basic(self, ray_start_regular):
"""
Expect output or dot_source:
MultiOutputNode" fillcolor=yellow shape=rectangle style=filled]
0 -> 1 [label=SharedMemoryType]
1 -> 2 [label=SharedMemoryType]
"""

@ray.remote
class Actor:
def echo(self, x):
return x

actor = Actor.remote()

with InputNode() as i:
dag = actor.echo.bind(i)

compiled_dag = dag.experimental_compile()

# Call the visualize method
dot_source = compiled_dag.visualize(return_dot=True)

graphs = pydot.graph_from_dot_data(dot_source)
graph = graphs[0]

node_names = {node.get_name() for node in graph.get_nodes()}
edge_pairs = {
(edge.get_source(), edge.get_destination()) for edge in graph.get_edges()
}

expected_nodes = {"0", "1", "2"}
assert expected_nodes.issubset(
node_names
), f"Expected nodes {expected_nodes} not found."

expected_edges = {("0", "1"), ("1", "2")}
assert expected_edges.issubset(
edge_pairs
), f"Expected edges {expected_edges} not found."

compiled_dag.teardown()

def test_visualize_multi_return(self, ray_start_regular):
"""
Expect output or dot_source:
MultiOutputNode" fillcolor=yellow shape=rectangle style=filled]
0 -> 1 [label=SharedMemoryType]
1 -> 2 [label=SharedMemoryType]
1 -> 3 [label=SharedMemoryType]
2 -> 4 [label=SharedMemoryType]
3 -> 4 [label=SharedMemoryType]
"""

@ray.remote
class Actor:
@ray.method(num_returns=2)
def return_two(self, x):
return x, x + 1

actor = Actor.remote()

with InputNode() as i:
o1, o2 = actor.return_two.bind(i)
dag = MultiOutputNode([o1, o2])

compiled_dag = dag.experimental_compile()

# Get the DOT source
dot_source = compiled_dag.visualize(return_dot=True)

graphs = pydot.graph_from_dot_data(dot_source)
graph = graphs[0]

node_names = {node.get_name() for node in graph.get_nodes()}
edge_pairs = {
(edge.get_source(), edge.get_destination()) for edge in graph.get_edges()
}

expected_nodes = {"0", "1", "2", "3", "4"}
assert expected_nodes.issubset(
node_names
), f"Expected nodes {expected_nodes} not found."

expected_edges = {("0", "1"), ("1", "2"), ("1", "3"), ("2", "4"), ("3", "4")}
assert expected_edges.issubset(
edge_pairs
), f"Expected edges {expected_edges} not found."

compiled_dag.teardown()

def test_visualize_multi_return2(self, ray_start_regular):
"""
Expect output or dot_source:
MultiOutputNode" fillcolor=yellow shape=rectangle style=filled]
0 -> 1 [label=SharedMemoryType]
1 -> 2 [label=SharedMemoryType]
1 -> 3 [label=SharedMemoryType]
2 -> 4 [label=SharedMemoryType]
3 -> 5 [label=SharedMemoryType]
4 -> 6 [label=SharedMemoryType]
5 -> 6 [label=SharedMemoryType]
"""

@ray.remote
class Actor:
@ray.method(num_returns=2)
def return_two(self, x):
return x, x + 1

def echo(self, x):
return x

a = Actor.remote()
b = Actor.remote()
with InputNode() as i:
o1, o2 = a.return_two.bind(i)
o3 = b.echo.bind(o1)
o4 = b.echo.bind(o2)
dag = MultiOutputNode([o3, o4])

compiled_dag = dag.experimental_compile()

# Get the DOT source
dot_source = compiled_dag.visualize(return_dot=True)

graphs = pydot.graph_from_dot_data(dot_source)
graph = graphs[0]

node_names = {node.get_name() for node in graph.get_nodes()}
edge_pairs = {
(edge.get_source(), edge.get_destination()) for edge in graph.get_edges()
}

expected_nodes = {"0", "1", "2", "3", "4", "5", "6"}
assert expected_nodes.issubset(
node_names
), f"Expected nodes {expected_nodes} not found."

expected_edges = {
("0", "1"),
("1", "2"),
("1", "3"),
("2", "4"),
("3", "5"),
("4", "6"),
("5", "6"),
}
assert expected_edges.issubset(
edge_pairs
), f"Expected edges {expected_edges} not found."

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down