diff --git a/crates/computegraph/src/lib.rs b/crates/computegraph/src/lib.rs index 09506bf4..9b02deb4 100644 --- a/crates/computegraph/src/lib.rs +++ b/crates/computegraph/src/lib.rs @@ -859,446 +859,3 @@ pub trait NodeFactory: ExecutableNode { /// A handle of type `Self::Handle` that can be used to interact with the node. fn create_handle(gnode: &GraphNode) -> Self::Handle; } - -#[cfg(test)] -mod tests { - use anyhow::{anyhow, Result}; - use std::any::TypeId; - - use super::*; - extern crate self as computegraph; - - #[derive(Debug)] - struct TestNodeConstant { - value: usize, - } - - impl TestNodeConstant { - pub const fn new(value: usize) -> Self { - Self { value } - } - } - - #[node(TestNodeConstant)] - fn run(&self) -> usize { - self.value - } - - #[derive(Debug)] - struct TestNodeAddition {} - - impl TestNodeAddition { - pub const fn new() -> Self { - Self {} - } - } - - #[node(TestNodeAddition)] - fn run(&self, a: &usize, b: &usize) -> usize { - *a + *b - } - - #[derive(Debug)] - struct TestNodeNumToString {} - - impl TestNodeNumToString { - pub const fn new() -> Self { - Self {} - } - } - - #[node(TestNodeNumToString)] - fn run(&self, input: &usize) -> String { - input.to_string() - } - - #[test] - fn test_basic_graph() -> Result<()> { - let mut graph = ComputeGraph::new(); - let value1 = graph.add_node(TestNodeConstant::new(9), "value1".to_string())?; - let value2 = graph.add_node(TestNodeConstant::new(10), "value2".to_string())?; - - let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; - - graph.connect(value1.output(), addition.input_a())?; - graph.connect(value2.output(), addition.input_b())?; - - assert_eq!(graph.compute(value1.output())?, 9); - assert_eq!(graph.compute(value2.output())?, 10); - assert_eq!(graph.compute(addition.output())?, 19); - - Ok(()) - } - - #[test] - fn test_diamond_dependencies() -> Result<()> { - // Here we will test a more complex graph with two diamond dependencies between nodes. - // The graph will look like this: - // - // value1──┐ - // └─►┌─────────┐ - // │addition1├────────────┐ - // ┌─►└─────────┘ └────►┌─────────┐ - // value2──┤ │addition4│ - // └─►┌─────────┐ ┌─►┌─────────┐ ┌►└────┬────┘ - // │addition2├─┤ │addition3├─┘ │ - // ┌─►└─────────┘ └─►└─────────┘ │ - // value3──┘ ▼ - // result - // - // So The result should be: - let function = |v1: usize, v2: usize, v3: usize| v1 + v2 + 2 * (v2 + v3); - - let mut graph = ComputeGraph::new(); - let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; - let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; - let value3 = graph.add_node(TestNodeConstant::new(3), "value3".to_string())?; - - let addition1 = graph.add_node(TestNodeAddition::new(), "addition1".to_string())?; - let addition2 = graph.add_node(TestNodeAddition::new(), "addition2".to_string())?; - let addition3 = graph.add_node(TestNodeAddition::new(), "addition3".to_string())?; - let addition4 = graph.add_node(TestNodeAddition::new(), "addition4".to_string())?; - - graph.connect(value1.output(), addition1.input_a())?; - graph.connect(value2.output(), addition1.input_b())?; - graph.connect(value2.output(), addition2.input_a())?; - graph.connect(value3.output(), addition2.input_b())?; - graph.connect(addition2.output(), addition3.input_a())?; - graph.connect(addition2.output(), addition3.input_b())?; - graph.connect(addition1.output(), addition4.input_a())?; - graph.connect(addition3.output(), addition4.input_b())?; - - assert_eq!(graph.compute(addition4.output())?, function(5, 7, 3)); - - Ok(()) - } - - #[test] - fn test_invalid_graph_missing_input() -> Result<()> { - let mut graph = ComputeGraph::new(); - let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; - let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; - - graph.connect(value.output(), addition.input_a())?; - - match graph.compute(addition.output()) { - Err(ComputeError::InputPortNotConnected(err)) => { - assert_eq!(err.node, addition.handle); - assert_eq!(err.input_name, "b"); - } - _ => panic!("Expected ComputeError::InputPortNotConnected"), - } - - Ok(()) - } - - #[test] - fn test_invalid_graph_type_mismatch() -> Result<()> { - let mut graph = ComputeGraph::new(); - let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; - let to_string = graph.add_node(TestNodeNumToString::new(), "to_string".to_string())?; - let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; - - graph.connect(value.output(), to_string.input())?; - graph.connect(value.output(), addition.input_a())?; - let res = graph.connect_untyped(to_string.output().into(), addition.input_b().into()); - match res { - Err(ConnectError::TypeMismatch { expected, found }) => { - assert_eq!(expected, TypeId::of::()); - assert_eq!(found, TypeId::of::()); - } - _ => panic!("Expected ConnectError::TypeMismatch"), - } - - Ok(()) - } - - #[test] - fn test_cycle_detection() -> Result<()> { - let mut graph = ComputeGraph::new(); - let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; - let node1 = graph.add_node(TestNodeAddition::new(), "node1".to_string())?; - let node2 = graph.add_node(TestNodeAddition::new(), "node2".to_string())?; - let node3 = graph.add_node(TestNodeAddition::new(), "node3".to_string())?; - - graph.connect(node1.output(), node2.input_a())?; - graph.connect(node2.output(), node3.input_a())?; - graph.connect(node3.output(), node1.input_a())?; - - graph.connect(value.output(), node2.input_b())?; - graph.connect(value.output(), node3.input_b())?; - graph.connect(value.output(), node1.input_b())?; - - // The graph contains a cycle: node1 -> node2 -> node3 -> node1 -> ... - let result = graph.compute(node1.output()); - assert!(result.is_err()); - - Ok(()) - } - - #[test] - fn test_edge_disconnection() -> Result<()> { - let mut graph = ComputeGraph::new(); - let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; - let one = graph.add_node(TestNodeConstant::new(1), "one".to_string())?; - let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; - let to_string = graph.add_node(TestNodeNumToString::new(), "to_string".to_string())?; - - let value_to_addition = graph.connect(value.output(), addition.input_a())?; - graph.connect(one.output(), addition.input_b())?; - graph.connect(addition.output(), to_string.input())?; - - // Test that the graph works before disconnecting the edge - assert_eq!(graph.compute(to_string.output())?, "6".to_string()); - - // Disconnect the edge between value and addition nodes - graph.disconnect(&value_to_addition)?; - - // Test that the graph fails after disconnecting the edge with the expected error - match graph.compute(to_string.output()) { - Err(ComputeError::InputPortNotConnected(port)) => { - assert_eq!(port.node, addition.handle); - assert_eq!(port.input_name, "a"); - } - _ => panic!("Expected ComputeError::InputPortNotConnected"), - } - - // Now reconnect the edge and test that the graph works again - graph.connect(value.output(), addition.input_a())?; - assert_eq!(graph.compute(to_string.output())?, "6".to_string()); - - Ok(()) - } - - #[test] - fn test_disconnected_subgraphs() -> Result<()> { - let mut graph = ComputeGraph::new(); - - // Subgraph 1: Addition - let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; - let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; - let addition1 = graph.add_node(TestNodeAddition::new(), "addition1".to_string())?; - graph.connect(value1.output(), addition1.input_a())?; - graph.connect(value2.output(), addition1.input_b())?; - - // Subgraph 2: Addition - let value3 = graph.add_node(TestNodeConstant::new(3), "value3".to_string())?; - let value4 = graph.add_node(TestNodeConstant::new(4), "value4".to_string())?; - let addition2 = graph.add_node(TestNodeAddition::new(), "addition2".to_string())?; - graph.connect(value3.output(), addition2.input_a())?; - graph.connect(value4.output(), addition2.input_b())?; - - // Compute the results of the disconnected subgraphs independently - assert_eq!(graph.compute(addition1.output())?, 12); - - assert_eq!(graph.compute(addition2.output())?, 7); - - Ok(()) - } - - #[test] - fn test_node_removal() -> Result<()> { - let mut graph = ComputeGraph::new(); - let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; - let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; - let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; - - graph.connect(value1.output(), addition.input_a())?; - graph.connect(value2.output(), addition.input_b())?; - - // Compute the result before removing a node - assert_eq!(graph.compute(addition.output())?, 12); - - // Remove the 'value2' node from the graph - graph.remove_node(value2.handle)?; - - // After removing 'value2', the 'addition' node should have a missing input - match graph.compute(addition.output()) { - Err(ComputeError::InputPortNotConnected(port)) => { - assert_eq!(port.node, addition.handle); - assert_eq!(port.input_name, "b"); - } - _ => panic!("Expected ComputeError::InputPortNotConnected"), - } - - // Ensure that the 'value1' node can still be computed - assert_eq!(graph.compute(value1.output())?, 5); - - // Now connect value1 to both inputs of the addition node - graph.connect(value1.output(), addition.input_b())?; - - // Compute the result after reconnecting the edge - assert_eq!(graph.compute(addition.output())?, 10); - - Ok(()) - } - - #[test] - fn test_connect_already_connected() -> Result<()> { - let mut graph = ComputeGraph::new(); - - let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; - let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; - let to_string = graph.add_node(TestNodeNumToString::new(), "to_string".to_string())?; - - graph.connect(value1.output(), to_string.input())?; - let res = graph.connect(value2.output(), to_string.input()); - match res { - Err(ConnectError::InputPortAlreadyConnected { from, to }) => { - assert_eq!(from.node, value2.handle); - assert_eq!(to.node, to_string.handle); - } - _ => panic!("Expected ConnectError::InputPortAlreadyConnected"), - } - - Ok(()) - } - - #[test] - fn test_duplicate_node_names() -> Result<()> { - let mut graph = ComputeGraph::new(); - - graph.add_node(TestNodeConstant::new(5), "value".to_string())?; - match graph.add_node(TestNodeConstant::new(7), "value".to_string()) { - Err(AddError::DuplicateName(name)) => { - assert_eq!(name, "value"); - } - _ => panic!("Expected AddError::DuplicateName"), - } - - Ok(()) - } - - #[test] - fn test_metadata() -> Result<()> { - #[derive(Debug, PartialEq)] - struct SomeMetadata; - #[derive(Debug, PartialEq)] - struct OtherMetadata(usize); - - let mut graph = ComputeGraph::new(); - let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; - let value_node = graph - .get_node_mut(&value.handle) - .ok_or_else(|| anyhow!("value node not found"))?; - - assert_eq!(value_node.metadata.get::(), None); - value_node.metadata.insert(SomeMetadata); - assert_eq!( - value_node.metadata.get::(), - Some(&SomeMetadata) - ); - value_node.metadata.remove::(); - value_node.metadata.insert(OtherMetadata(42)); - - let value_node = graph - .get_node(&value.handle) - .ok_or_else(|| anyhow!("value node not found"))?; - assert_eq!(value_node.metadata.get::(), None); - assert_eq!(value_node.metadata.get(), Some(&OtherMetadata(42))); - Ok(()) - } - - #[test] - fn test_macro_node() { - #[derive(Debug)] - struct Node1 {} - #[node(Node1)] - fn run(&self) {} - - #[derive(Debug)] // TODO: why do we need this? - struct Node2 {} - #[node(Node2)] - fn run(&self) -> usize { - 21 - } - - #[derive(Debug)] - struct Node3 {} - #[node(Node3 -> hello)] - fn run(&self) -> String { - "hello".to_string() - } - - #[derive(Debug)] - struct Node4 {} - - #[node(Node4 -> (hello, world))] - fn run(&self) -> (String, String) { - ("hello".to_string(), "world".to_string()) - } - - #[derive(Debug)] - struct Node5 {} - #[node(Node5)] - fn run(&self, input: &usize) -> usize { - *input - } - - #[derive(Debug)] - struct Node6 {} - #[node(Node6 -> output)] - fn run(&self, text: &String, repeat_count: &usize) -> String { - text.repeat(*repeat_count) - } - - // TODO: generics support - - assert_eq!(::inputs(), vec![]); - assert_eq!(::outputs(), vec![]); - let res = ExecutableNode::run(&Node1 {}, &[]); - assert_eq!(res.len(), 0); - - assert_eq!(::inputs(), vec![]); - assert_eq!( - ::outputs(), - vec![("output", TypeId::of::())] - ); - - assert_eq!(::inputs(), vec![]); - assert_eq!( - ::outputs(), - vec![("hello", TypeId::of::())] - ); - - assert_eq!(::inputs(), vec![]); - assert_eq!( - ::outputs(), - vec![ - ("hello", TypeId::of::()), - ("world", TypeId::of::()) - ] - ); - let res = ExecutableNode::run(&Node4 {}, &[]); - assert_eq!(res.len(), 2); - assert_eq!(res[0].downcast_ref::().unwrap(), "hello"); - assert_eq!(res[1].downcast_ref::().unwrap(), "world"); - - assert_eq!( - ::inputs(), - vec![("input", TypeId::of::())] - ); - assert_eq!( - ::outputs(), - vec![("output", TypeId::of::())] - ); - let res = ExecutableNode::run(&Node6 {}, &[Box::new("hi".to_string()), Box::new(3_usize)]); - assert_eq!(res.len(), 1); - assert_eq!(res[0].downcast_ref::().unwrap(), "hihihi"); - - assert_eq!( - ::inputs(), - vec![ - ("text", TypeId::of::()), - ("repeat_count", TypeId::of::()) - ] - ); - assert_eq!( - ::outputs(), - vec![("output", TypeId::of::())] - ); - let res = ExecutableNode::run(&Node6 {}, &[Box::new("hi".to_string()), Box::new(3_usize)]); - assert_eq!(res.len(), 1); - assert_eq!(res[0].downcast_ref::().unwrap(), "hihihi"); - } -} diff --git a/crates/computegraph/tests/common/mod.rs b/crates/computegraph/tests/common/mod.rs new file mode 100644 index 00000000..97506400 --- /dev/null +++ b/crates/computegraph/tests/common/mod.rs @@ -0,0 +1,46 @@ +#![allow(dead_code)] +use computegraph::*; + +#[derive(Debug)] +pub struct TestNodeConstant { + value: usize, +} + +impl TestNodeConstant { + pub const fn new(value: usize) -> Self { + Self { value } + } +} + +#[node(TestNodeConstant)] +fn run(&self) -> usize { + self.value +} + +#[derive(Debug)] +pub struct TestNodeAddition {} + +impl TestNodeAddition { + pub const fn new() -> Self { + Self {} + } +} + +#[node(TestNodeAddition)] +fn run(&self, a: &usize, b: &usize) -> usize { + *a + *b +} + +#[derive(Debug)] +pub struct TestNodeNumToString {} + +impl TestNodeNumToString { + pub const fn new() -> Self { + Self {} + } +} + +#[node(TestNodeNumToString)] +fn run(&self, input: &usize) -> String { + input.to_string() +} diff --git a/crates/computegraph/tests/dependencies.rs b/crates/computegraph/tests/dependencies.rs new file mode 100644 index 00000000..a1a1be96 --- /dev/null +++ b/crates/computegraph/tests/dependencies.rs @@ -0,0 +1,156 @@ +mod common; +use std::any::TypeId; + +use anyhow::Result; +use common::*; +use computegraph::*; + +#[test] +fn test_basic_graph() -> Result<()> { + let mut graph = ComputeGraph::new(); + let value1 = graph.add_node(TestNodeConstant::new(9), "value1".to_string())?; + let value2 = graph.add_node(TestNodeConstant::new(10), "value2".to_string())?; + + let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; + + graph.connect(value1.output(), addition.input_a())?; + graph.connect(value2.output(), addition.input_b())?; + + assert_eq!(graph.compute(value1.output())?, 9); + assert_eq!(graph.compute(value2.output())?, 10); + assert_eq!(graph.compute(addition.output())?, 19); + + Ok(()) +} + +#[test] +fn test_diamond_dependencies() -> Result<()> { + // Here we will test a more complex graph with two diamond dependencies between nodes. + // The graph will look like this: + // + // value1──┐ + // └─►┌─────────┐ + // │addition1├────────────┐ + // ┌─►└─────────┘ └────►┌─────────┐ + // value2──┤ │addition4│ + // └─►┌─────────┐ ┌─►┌─────────┐ ┌►└────┬────┘ + // │addition2├─┤ │addition3├─┘ │ + // ┌─►└─────────┘ └─►└─────────┘ │ + // value3──┘ ▼ + // result + // + // So The result should be: + let function = |v1: usize, v2: usize, v3: usize| v1 + v2 + 2 * (v2 + v3); + + let mut graph = ComputeGraph::new(); + let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; + let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; + let value3 = graph.add_node(TestNodeConstant::new(3), "value3".to_string())?; + + let addition1 = graph.add_node(TestNodeAddition::new(), "addition1".to_string())?; + let addition2 = graph.add_node(TestNodeAddition::new(), "addition2".to_string())?; + let addition3 = graph.add_node(TestNodeAddition::new(), "addition3".to_string())?; + let addition4 = graph.add_node(TestNodeAddition::new(), "addition4".to_string())?; + + graph.connect(value1.output(), addition1.input_a())?; + graph.connect(value2.output(), addition1.input_b())?; + graph.connect(value2.output(), addition2.input_a())?; + graph.connect(value3.output(), addition2.input_b())?; + graph.connect(addition2.output(), addition3.input_a())?; + graph.connect(addition2.output(), addition3.input_b())?; + graph.connect(addition1.output(), addition4.input_a())?; + graph.connect(addition3.output(), addition4.input_b())?; + + assert_eq!(graph.compute(addition4.output())?, function(5, 7, 3)); + + Ok(()) +} + +#[test] +fn test_invalid_graph_missing_input() -> Result<()> { + let mut graph = ComputeGraph::new(); + let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; + let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; + + graph.connect(value.output(), addition.input_a())?; + + match graph.compute(addition.output()) { + Err(ComputeError::InputPortNotConnected(err)) => { + assert_eq!(err.node, addition.handle); + assert_eq!(err.input_name, "b"); + } + _ => panic!("Expected ComputeError::InputPortNotConnected"), + } + + Ok(()) +} + +#[test] +fn test_invalid_graph_type_mismatch() -> Result<()> { + let mut graph = ComputeGraph::new(); + let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; + let to_string = graph.add_node(TestNodeNumToString::new(), "to_string".to_string())?; + let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; + + graph.connect(value.output(), to_string.input())?; + graph.connect(value.output(), addition.input_a())?; + let res = graph.connect_untyped(to_string.output().into(), addition.input_b().into()); + match res { + Err(ConnectError::TypeMismatch { expected, found }) => { + assert_eq!(expected, TypeId::of::()); + assert_eq!(found, TypeId::of::()); + } + _ => panic!("Expected ConnectError::TypeMismatch"), + } + + Ok(()) +} + +#[test] +fn test_cycle_detection() -> Result<()> { + let mut graph = ComputeGraph::new(); + let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; + let node1 = graph.add_node(TestNodeAddition::new(), "node1".to_string())?; + let node2 = graph.add_node(TestNodeAddition::new(), "node2".to_string())?; + let node3 = graph.add_node(TestNodeAddition::new(), "node3".to_string())?; + + graph.connect(node1.output(), node2.input_a())?; + graph.connect(node2.output(), node3.input_a())?; + graph.connect(node3.output(), node1.input_a())?; + + graph.connect(value.output(), node2.input_b())?; + graph.connect(value.output(), node3.input_b())?; + graph.connect(value.output(), node1.input_b())?; + + // The graph contains a cycle: node1 -> node2 -> node3 -> node1 -> ... + let result = graph.compute(node1.output()); + assert!(result.is_err()); + + Ok(()) +} + +#[test] +fn test_disconnected_subgraphs() -> Result<()> { + let mut graph = ComputeGraph::new(); + + // Subgraph 1: Addition + let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; + let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; + let addition1 = graph.add_node(TestNodeAddition::new(), "addition1".to_string())?; + graph.connect(value1.output(), addition1.input_a())?; + graph.connect(value2.output(), addition1.input_b())?; + + // Subgraph 2: Addition + let value3 = graph.add_node(TestNodeConstant::new(3), "value3".to_string())?; + let value4 = graph.add_node(TestNodeConstant::new(4), "value4".to_string())?; + let addition2 = graph.add_node(TestNodeAddition::new(), "addition2".to_string())?; + graph.connect(value3.output(), addition2.input_a())?; + graph.connect(value4.output(), addition2.input_b())?; + + // Compute the results of the disconnected subgraphs independently + assert_eq!(graph.compute(addition1.output())?, 12); + + assert_eq!(graph.compute(addition2.output())?, 7); + + Ok(()) +} diff --git a/crates/computegraph/tests/graph_construction.rs b/crates/computegraph/tests/graph_construction.rs new file mode 100644 index 00000000..b28fbe57 --- /dev/null +++ b/crates/computegraph/tests/graph_construction.rs @@ -0,0 +1,41 @@ +mod common; + +use anyhow::Result; +use common::*; +use computegraph::*; + +#[test] +fn test_connect_already_connected() -> Result<()> { + let mut graph = ComputeGraph::new(); + + let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; + let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; + let to_string = graph.add_node(TestNodeNumToString::new(), "to_string".to_string())?; + + graph.connect(value1.output(), to_string.input())?; + let res = graph.connect(value2.output(), to_string.input()); + match res { + Err(ConnectError::InputPortAlreadyConnected { from, to }) => { + assert_eq!(from.node, value2.handle); + assert_eq!(to.node, to_string.handle); + } + _ => panic!("Expected ConnectError::InputPortAlreadyConnected"), + } + + Ok(()) +} + +#[test] +fn test_duplicate_node_names() -> Result<()> { + let mut graph = ComputeGraph::new(); + + graph.add_node(TestNodeConstant::new(5), "value".to_string())?; + match graph.add_node(TestNodeConstant::new(7), "value".to_string()) { + Err(AddError::DuplicateName(name)) => { + assert_eq!(name, "value"); + } + _ => panic!("Expected AddError::DuplicateName"), + } + + Ok(()) +} diff --git a/crates/computegraph/tests/graph_manipulation.rs b/crates/computegraph/tests/graph_manipulation.rs new file mode 100644 index 00000000..2e3ce284 --- /dev/null +++ b/crates/computegraph/tests/graph_manipulation.rs @@ -0,0 +1,76 @@ +mod common; + +use anyhow::Result; +use common::*; +use computegraph::*; + +#[test] +fn test_edge_disconnection() -> Result<()> { + let mut graph = ComputeGraph::new(); + let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; + let one = graph.add_node(TestNodeConstant::new(1), "one".to_string())?; + let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; + let to_string = graph.add_node(TestNodeNumToString::new(), "to_string".to_string())?; + + let value_to_addition = graph.connect(value.output(), addition.input_a())?; + graph.connect(one.output(), addition.input_b())?; + graph.connect(addition.output(), to_string.input())?; + + // Test that the graph works before disconnecting the edge + assert_eq!(graph.compute(to_string.output())?, "6".to_string()); + + // Disconnect the edge between value and addition nodes + graph.disconnect(&value_to_addition)?; + + // Test that the graph fails after disconnecting the edge with the expected error + match graph.compute(to_string.output()) { + Err(ComputeError::InputPortNotConnected(port)) => { + assert_eq!(port.node, addition.handle); + assert_eq!(port.input_name, "a"); + } + _ => panic!("Expected ComputeError::InputPortNotConnected"), + } + + // Now reconnect the edge and test that the graph works again + graph.connect(value.output(), addition.input_a())?; + assert_eq!(graph.compute(to_string.output())?, "6".to_string()); + + Ok(()) +} + +#[test] +fn test_node_removal() -> Result<()> { + let mut graph = ComputeGraph::new(); + let value1 = graph.add_node(TestNodeConstant::new(5), "value1".to_string())?; + let value2 = graph.add_node(TestNodeConstant::new(7), "value2".to_string())?; + let addition = graph.add_node(TestNodeAddition::new(), "addition".to_string())?; + + graph.connect(value1.output(), addition.input_a())?; + graph.connect(value2.output(), addition.input_b())?; + + // Compute the result before removing a node + assert_eq!(graph.compute(addition.output())?, 12); + + // Remove the 'value2' node from the graph + graph.remove_node(value2.handle)?; + + // After removing 'value2', the 'addition' node should have a missing input + match graph.compute(addition.output()) { + Err(ComputeError::InputPortNotConnected(port)) => { + assert_eq!(port.node, addition.handle); + assert_eq!(port.input_name, "b"); + } + _ => panic!("Expected ComputeError::InputPortNotConnected"), + } + + // Ensure that the 'value1' node can still be computed + assert_eq!(graph.compute(value1.output())?, 5); + + // Now connect value1 to both inputs of the addition node + graph.connect(value1.output(), addition.input_b())?; + + // Compute the result after reconnecting the edge + assert_eq!(graph.compute(addition.output())?, 10); + + Ok(()) +} diff --git a/crates/computegraph/tests/macros.rs b/crates/computegraph/tests/macros.rs new file mode 100644 index 00000000..0954d356 --- /dev/null +++ b/crates/computegraph/tests/macros.rs @@ -0,0 +1,105 @@ +use computegraph::{node, ExecutableNode, NodeFactory}; +use std::any::TypeId; + +#[test] +fn test_macro_node() { + #[derive(Debug)] + struct Node1 {} + #[node(Node1)] + fn run(&self) {} + + #[derive(Debug)] // TODO: why do we need this? + struct Node2 {} + #[node(Node2)] + fn run(&self) -> usize { + 21 + } + + #[derive(Debug)] + struct Node3 {} + #[node(Node3 -> hello)] + fn run(&self) -> String { + "hello".to_string() + } + + #[derive(Debug)] + struct Node4 {} + + #[node(Node4 -> (hello, world))] + fn run(&self) -> (String, String) { + ("hello".to_string(), "world".to_string()) + } + + #[derive(Debug)] + struct Node5 {} + #[node(Node5)] + fn run(&self, input: &usize) -> usize { + *input + } + + #[derive(Debug)] + struct Node6 {} + #[node(Node6 -> output)] + fn run(&self, text: &String, repeat_count: &usize) -> String { + text.repeat(*repeat_count) + } + + // TODO: generics support + + assert_eq!(::inputs(), vec![]); + assert_eq!(::outputs(), vec![]); + let res = ExecutableNode::run(&Node1 {}, &[]); + assert_eq!(res.len(), 0); + + assert_eq!(::inputs(), vec![]); + assert_eq!( + ::outputs(), + vec![("output", TypeId::of::())] + ); + + assert_eq!(::inputs(), vec![]); + assert_eq!( + ::outputs(), + vec![("hello", TypeId::of::())] + ); + + assert_eq!(::inputs(), vec![]); + assert_eq!( + ::outputs(), + vec![ + ("hello", TypeId::of::()), + ("world", TypeId::of::()) + ] + ); + let res = ExecutableNode::run(&Node4 {}, &[]); + assert_eq!(res.len(), 2); + assert_eq!(res[0].downcast_ref::().unwrap(), "hello"); + assert_eq!(res[1].downcast_ref::().unwrap(), "world"); + + assert_eq!( + ::inputs(), + vec![("input", TypeId::of::())] + ); + assert_eq!( + ::outputs(), + vec![("output", TypeId::of::())] + ); + let res = ExecutableNode::run(&Node6 {}, &[Box::new("hi".to_string()), Box::new(3_usize)]); + assert_eq!(res.len(), 1); + assert_eq!(res[0].downcast_ref::().unwrap(), "hihihi"); + + assert_eq!( + ::inputs(), + vec![ + ("text", TypeId::of::()), + ("repeat_count", TypeId::of::()) + ] + ); + assert_eq!( + ::outputs(), + vec![("output", TypeId::of::())] + ); + let res = ExecutableNode::run(&Node6 {}, &[Box::new("hi".to_string()), Box::new(3_usize)]); + assert_eq!(res.len(), 1); + assert_eq!(res[0].downcast_ref::().unwrap(), "hihihi"); +} diff --git a/crates/computegraph/tests/metadata.rs b/crates/computegraph/tests/metadata.rs new file mode 100644 index 00000000..dd3b8c3e --- /dev/null +++ b/crates/computegraph/tests/metadata.rs @@ -0,0 +1,35 @@ +mod common; + +use anyhow::{anyhow, Result}; +use common::*; +use computegraph::*; + +#[test] +fn test_metadata() -> Result<()> { + #[derive(Debug, PartialEq)] + struct SomeMetadata; + #[derive(Debug, PartialEq)] + struct OtherMetadata(usize); + + let mut graph = ComputeGraph::new(); + let value = graph.add_node(TestNodeConstant::new(5), "value".to_string())?; + let value_node = graph + .get_node_mut(&value.handle) + .ok_or_else(|| anyhow!("value node not found"))?; + + assert_eq!(value_node.metadata.get::(), None); + value_node.metadata.insert(SomeMetadata); + assert_eq!( + value_node.metadata.get::(), + Some(&SomeMetadata) + ); + value_node.metadata.remove::(); + value_node.metadata.insert(OtherMetadata(42)); + + let value_node = graph + .get_node(&value.handle) + .ok_or_else(|| anyhow!("value node not found"))?; + assert_eq!(value_node.metadata.get::(), None); + assert_eq!(value_node.metadata.get(), Some(&OtherMetadata(42))); + Ok(()) +}