From ffb1973ebee99e2dc7d9e8eb79ea0d6acfd5d626 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 22 Aug 2023 10:35:56 -0400 Subject: [PATCH] Handle token_swapper impossible swap mapping with Error (#971) * Handle token_swapper impossible swap mapping with Error This commit fixes the handling of invalid mapping requests in the token_swapper rustworkx-core function and the graph_token_swapper function in rustworkx that uses it. Previously if an invalid mapping was requested the function would internally panic because it always assumed there was a path in the graph to fulfill the user requested mapping. However, because the rustworkx-core function didn't support error returns a breaking api change is needed to add a result return type. * Fix formatting --------- Co-authored-by: Edwin Navarro --- ...apping-token_swapper-55d5b045b0b55345.yaml | 34 +++++ rustworkx-core/src/token_swapper.rs | 120 ++++++++++++++---- src/lib.rs | 3 + src/token_swapper.rs | 17 ++- tests/rustworkx_tests/test_token_swapper.py | 8 ++ 5 files changed, 154 insertions(+), 28 deletions(-) create mode 100644 releasenotes/notes/handle-invalid-mapping-token_swapper-55d5b045b0b55345.yaml diff --git a/releasenotes/notes/handle-invalid-mapping-token_swapper-55d5b045b0b55345.yaml b/releasenotes/notes/handle-invalid-mapping-token_swapper-55d5b045b0b55345.yaml new file mode 100644 index 0000000000..5cc97023e0 --- /dev/null +++ b/releasenotes/notes/handle-invalid-mapping-token_swapper-55d5b045b0b55345.yaml @@ -0,0 +1,34 @@ +--- +features: + - | + Added a new exception class :class:`~.InvalidMapping` which is raised when a function receives an invalid + mapping. The sole user of this exception is the :func:`~.graph_token_swapper` which will raise it when + the user provided mapping is not feasible on the provided graph. +upgrade: + - | + The rustworkx function :func:`~.graph_token_swapper` now will raise an :class:`~.InvalidMapping` exception + instead of a ``PanicException`` when an invalid mapping is requested. This was done because a + ``PanicException`` is difficult to catch by design as it is used to indicate an unhandled error. Using + - | + The return type of the ``rustworkx-core`` function ``token_swapper()`` has been changed + from ``Vec<(NodeIndex, NodeIndex)>`` to be ``Result, MapNotPossible>``. + This change was necessary to return an expected error condition if a mapping is requested for a graph + that is not possible. For example is if you have a disjoint graph and you're trying to map + nodes without any connectivity: + + .. code-block:: rust + + use rustworkx_core::token_swapper; + use rustworkx_core::petgraph; + + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (2, 3) ]); + let mapping = HashMap::from([ + (NodeIndex::new(2), NodeIndex::new(0)), + (NodeIndex::new(1), NodeIndex::new(1)), + (NodeIndex::new(0), NodeIndex::new(2)), + (NodeIndex::new(3), NodeIndex::new(3)), + ]); + token_swapper(&g, mapping, Some(10), Some(4), Some(50)); + + will now return ``Err(MapNotPossible)`` instead of panicking. If you were using this + funciton before you'll need to handle the result type. diff --git a/rustworkx-core/src/token_swapper.rs b/rustworkx-core/src/token_swapper.rs index 065c969f85..c60e31425d 100644 --- a/rustworkx-core/src/token_swapper.rs +++ b/rustworkx-core/src/token_swapper.rs @@ -13,6 +13,8 @@ use rand::distributions::{Standard, Uniform}; use rand::prelude::*; use rand_pcg::Pcg64; +use std::error::Error; +use std::fmt; use std::hash::Hash; use hashbrown::HashMap; @@ -34,6 +36,18 @@ use crate::traversal::dfs_edges; type Swap = (NodeIndex, NodeIndex); type Edge = (NodeIndex, NodeIndex); +/// Error returned by token swapper if the request mapping +/// is impossible +#[derive(Debug, PartialEq, Eq, Ord, PartialOrd, Copy, Clone)] +pub struct MapNotPossible; + +impl Error for MapNotPossible {} +impl fmt::Display for MapNotPossible { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "No mapping possible.") + } +} + struct TokenSwapper where G::NodeId: Eq + Hash, @@ -85,7 +99,7 @@ where } } - fn map(&mut self) -> Vec { + fn map(&mut self) -> Result, MapNotPossible> { let num_nodes = self.graph.node_bound(); let num_edges = self.graph.edge_count(); @@ -141,7 +155,7 @@ where &mut digraph, &mut sub_digraph, &mut tokens, - ); + )?; } // First collect the self.trial number of random numbers // into a Vec based on the given seed @@ -165,7 +179,10 @@ where trial_seed, ) }) - .min_by_key(|result| result.len()) + .min_by_key(|result| match result { + Ok(res) => Ok(res.len()), + Err(e) => Err(*e), + }) .unwrap() } @@ -175,16 +192,16 @@ where digraph: &mut StableGraph<(), (), Directed>, sub_digraph: &mut StableGraph<(), (), Directed>, tokens: &mut HashMap, - ) { + ) -> Result<(), MapNotPossible> { // Adds an edge to digraph if distance from the token to a neighbor is // less than distance from token to node. sub_digraph is same except // for self-edges. if !(tokens.contains_key(&node)) { - return; + return Ok(()); } if tokens[&node] == node { digraph.update_edge(node, node, ()); - return; + return Ok(()); } let id_node = self.rev_node_map[&node]; let id_token = self.rev_node_map[&tokens[&node]]; @@ -207,12 +224,21 @@ where None, ) .unwrap(); + let neigh_dist = dist_neighbor.get(&id_token); + let node_dist = dist_node.get(&id_token); + if neigh_dist.is_none() { + return Err(MapNotPossible {}); + } + if node_dist.is_none() { + return Err(MapNotPossible {}); + } - if dist_neighbor[&id_token] < dist_node[&id_token] { + if neigh_dist < node_dist { digraph.update_edge(node, neighbor, ()); sub_digraph.update_edge(node, neighbor, ()); } } + Ok(()) } fn trial_map( @@ -222,7 +248,7 @@ where mut tokens: HashMap, mut todo_nodes: Vec, trial_seed: u64, - ) -> Vec { + ) -> Result, MapNotPossible> { // Create a random trial list of swaps to move tokens to optimal positions let mut steps = 0; let mut swap_edges: Vec = vec![]; @@ -245,7 +271,7 @@ where &mut sub_digraph, &mut tokens, &mut todo_nodes, - ); + )?; } steps += cycle.len() - 1; // If there's no cycle, see if there's an edge target that matches a token key. @@ -264,7 +290,7 @@ where &mut sub_digraph, &mut tokens, &mut todo_nodes, - ); + )?; steps += 1; found = true; break; @@ -288,7 +314,7 @@ where &mut sub_digraph, &mut tokens, &mut todo_nodes, - ); + )?; steps += 1; found = true; break; @@ -305,7 +331,7 @@ where todo_nodes.is_empty(), "The output final swap map is incomplete, this points to a bug in rustworkx, please open an issue." ); - swap_edges + Ok(swap_edges) } fn swap( @@ -316,7 +342,7 @@ where sub_digraph: &mut StableGraph<(), (), Directed>, tokens: &mut HashMap, todo_nodes: &mut Vec, - ) { + ) -> Result<(), MapNotPossible> { // Get token values for the 2 nodes and remove them let token1 = tokens.remove(&node1); let token2 = tokens.remove(&node2); @@ -347,7 +373,7 @@ where let edge = sub_digraph.find_edge(edge_node1, edge_node2).unwrap(); sub_digraph.remove_edge(edge); } - self.add_token_edges(node, digraph, sub_digraph, tokens); + self.add_token_edges(node, digraph, sub_digraph, tokens)?; // If a node is a token key and not equal to the value, add it to todo_nodes if tokens.contains_key(&node) && tokens[&node] != node { @@ -359,6 +385,7 @@ where todo_nodes.swap_remove(todo_nodes.iter().position(|x| *x == node).unwrap()); } } + Ok(()) } } @@ -378,7 +405,8 @@ where /// trigger the use of parallel threads. If the number of nodes in the graph is less than this value /// it will run in a single thread. The default value is 50. /// -/// It returns a list of tuples representing the swaps to perform. +/// It returns a list of tuples representing the swaps to perform. The result will be an +/// `Err(MapNotPossible)` if the `token_swapper()` function can't find a mapping. /// /// This function is multithreaded and will launch a thread pool with threads equal to /// the number of CPUs by default. You can tune the number of threads with @@ -400,7 +428,7 @@ where /// (NodeIndex::new(2), NodeIndex::new(2)), /// ]); /// // Do the token swap -/// let output = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); +/// let output = token_swapper(&g, mapping, Some(4), Some(4), Some(50)).expect("Swap mapping failed."); /// assert_eq!(3, output.len()); /// /// ``` @@ -411,7 +439,7 @@ pub fn token_swapper( trials: Option, seed: Option, parallel_threshold: Option, -) -> Vec +) -> Result, MapNotPossible> where G: NodeCount + EdgeCount @@ -470,7 +498,7 @@ mod test_token_swapper { (NodeIndex::new(2), NodeIndex::new(2)), ]); let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); - assert_eq!(3, swaps.len()); + assert_eq!(3, swaps.expect("swap mapping errored").len()); } #[test] @@ -491,7 +519,8 @@ mod test_token_swapper { } // Do the token swap let mut new_map = mapping.clone(); - let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + let swaps = + token_swapper(&g, mapping, Some(4), Some(4), Some(50)).expect("swap mapping errored"); do_swap(&mut new_map, &swaps); let mut expected = HashMap::with_capacity(8); for i in 0..8 { @@ -526,7 +555,8 @@ mod test_token_swapper { ]); // Do the token swap let mut new_map = mapping.clone(); - let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + let swaps = + token_swapper(&g, mapping, Some(4), Some(4), Some(50)).expect("swap mapping errored"); do_swap(&mut new_map, &swaps); let mut expected = HashMap::with_capacity(6); for i in (0..5).chain(6..7) { @@ -541,7 +571,8 @@ mod test_token_swapper { let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); let mapping = HashMap::from([(NodeIndex::new(0), NodeIndex::new(3))]); let mut new_map = mapping.clone(); - let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(1)); + let swaps = + token_swapper(&g, mapping, Some(4), Some(4), Some(1)).expect("swap mapping errored"); do_swap(&mut new_map, &swaps); let mut expected = HashMap::with_capacity(4); expected.insert(NodeIndex::new(3), NodeIndex::new(3)); @@ -557,7 +588,8 @@ mod test_token_swapper { g.remove_node(NodeIndex::new(2)); g.add_edge(NodeIndex::new(1), NodeIndex::new(3), ()); let mut new_map = mapping.clone(); - let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(1)); + let swaps = + token_swapper(&g, mapping, Some(4), Some(4), Some(1)).expect("swap mapping errored"); do_swap(&mut new_map, &swaps); let mut expected = HashMap::with_capacity(4); expected.insert(NodeIndex::new(3), NodeIndex::new(3)); @@ -573,7 +605,8 @@ mod test_token_swapper { (NodeIndex::new(1), NodeIndex::new(2)), ]); let mut new_map = mapping.clone(); - let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + let swaps = + token_swapper(&g, mapping, Some(4), Some(4), Some(50)).expect("swap mapping errored"); do_swap(&mut new_map, &swaps); let expected = HashMap::from([ (NodeIndex::new(2), NodeIndex::new(2)), @@ -629,8 +662,47 @@ mod test_token_swapper { let expected: HashMap = mapping.values().map(|val| (*val, *val)).collect(); - let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + let swaps = + token_swapper(&g, mapping, Some(4), Some(4), Some(50)).expect("swap mapping errored"); do_swap(&mut new_map, &swaps); assert_eq!(expected, new_map) } + + #[test] + fn test_disjoint_graph_works() { + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (2, 3)]); + let mapping = HashMap::from([ + (NodeIndex::new(1), NodeIndex::new(0)), + (NodeIndex::new(0), NodeIndex::new(1)), + (NodeIndex::new(2), NodeIndex::new(3)), + (NodeIndex::new(3), NodeIndex::new(2)), + ]); + let mut new_map = mapping.clone(); + let swaps = + token_swapper(&g, mapping, Some(10), Some(4), Some(50)).expect("swap mapping errored"); + do_swap(&mut new_map, &swaps); + let expected = HashMap::from([ + (NodeIndex::new(2), NodeIndex::new(2)), + (NodeIndex::new(3), NodeIndex::new(3)), + (NodeIndex::new(1), NodeIndex::new(1)), + (NodeIndex::new(0), NodeIndex::new(0)), + ]); + assert_eq!(2, swaps.len()); + assert_eq!(expected, new_map); + } + + #[test] + fn test_disjoint_graph_fails() { + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (2, 3)]); + let mapping = HashMap::from([ + (NodeIndex::new(2), NodeIndex::new(0)), + (NodeIndex::new(1), NodeIndex::new(1)), + (NodeIndex::new(0), NodeIndex::new(2)), + (NodeIndex::new(3), NodeIndex::new(3)), + ]); + match token_swapper(&g, mapping, Some(10), Some(4), Some(50)) { + Ok(_) => panic!("This should error"), + Err(_) => (), + }; + } } diff --git a/src/lib.rs b/src/lib.rs index 1b15d44592..d34e31f107 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -322,6 +322,8 @@ create_exception!(rustworkx, NoSuitableNeighbors, PyException); create_exception!(rustworkx, NullGraph, PyException); // No path was found between the specified nodes. create_exception!(rustworkx, NoPathFound, PyException); +// No mapping was found for the request swapping +create_exception!(rustworkx, InvalidMapping, PyException); // Prune part of the search tree while traversing a graph. import_exception!(rustworkx.visit, PruneSearch); // Stop graph traversal. @@ -342,6 +344,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add("DAGHasCycle", py.get_type::())?; m.add("NoSuitableNeighbors", py.get_type::())?; m.add("NoPathFound", py.get_type::())?; + m.add("InvalidMapping", py.get_type::())?; m.add("NullGraph", py.get_type::())?; m.add("NegativeCycle", py.get_type::())?; m.add( diff --git a/src/token_swapper.rs b/src/token_swapper.rs index 0adf5b85b7..a4b460a868 100644 --- a/src/token_swapper.rs +++ b/src/token_swapper.rs @@ -12,6 +12,7 @@ use crate::graph; use crate::iterators::EdgeList; +use crate::InvalidMapping; use hashbrown::HashMap; use petgraph::graph::NodeIndex; @@ -54,16 +55,24 @@ pub fn graph_token_swapper( trials: Option, seed: Option, parallel_threshold: Option, -) -> EdgeList { +) -> PyResult { let map: HashMap = mapping .iter() .map(|(s, t)| (NodeIndex::new(*s), NodeIndex::new(*t))) .collect(); - let swaps = token_swapper::token_swapper(&graph.graph, map, trials, seed, parallel_threshold); - EdgeList { + let swaps = + match token_swapper::token_swapper(&graph.graph, map, trials, seed, parallel_threshold) { + Ok(swaps) => swaps, + Err(_) => { + return Err(InvalidMapping::new_err( + "Specified mapping could not be made on the given graph", + )) + } + }; + Ok(EdgeList { edges: swaps .into_iter() .map(|(s, t)| (s.index(), t.index())) .collect(), - } + }) } diff --git a/tests/rustworkx_tests/test_token_swapper.py b/tests/rustworkx_tests/test_token_swapper.py index b5a207e32b..aafc6e6ff0 100644 --- a/tests/rustworkx_tests/test_token_swapper.py +++ b/tests/rustworkx_tests/test_token_swapper.py @@ -116,3 +116,11 @@ def test_large_partial_random(self) -> None: swaps = rx.graph_token_swapper(graph, permutation, 4, 4) swap_permutation(mapping, swaps) self.assertEqual({i: i for i in mapping.values()}, mapping) + + def test_disjoint_graph(self): + graph = rx.PyGraph() + graph.extend_from_edge_list([(0, 1), (2, 3)]) + swaps = rx.graph_token_swapper(graph, {1: 0, 0: 1, 2: 3, 3: 2}, 10, seed=42) + self.assertEqual(len(swaps), 2) + with self.assertRaises(rx.InvalidMapping): + rx.graph_token_swapper(graph, {2: 0, 1: 1, 0: 2, 3: 3}, 10)