Skip to content

Commit

Permalink
Add to_undirected method for PyDiGraph (#161)
Browse files Browse the repository at this point in the history
This commit adds a new method to the PyDiGraph class, to_undirected(),
which will generate an undirected PyGraph object from the PyDiGraph
object.

Fixes #153
  • Loading branch information
mtreinish authored Nov 2, 2020
1 parent 12eb6c3 commit bbf0f0c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ use petgraph::algo;
use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;
use petgraph::stable_graph::StableDiGraph;
use petgraph::stable_graph::StableUnGraph;

use petgraph::visit::{
GetAdjacencyMatrix, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges,
IntoEdgesDirected, IntoNeighbors, IntoNeighborsDirected,
Expand Down Expand Up @@ -1605,6 +1607,37 @@ impl PyDiGraph {
}
edges.is_empty()
}

/// Generate a new PyGraph object from this graph
///
/// This will create a new :class:`~retworkx.PyGraph` object from this
/// graph. All edges in this graph will be created as undirected edges in
/// the new graph object.
/// Do note that the node and edge weights/data payloads will be passed
/// by reference to the new :class:`~retworkx.PyGraph` object.
///
/// :returns: A new PyGraph object with an undirected edge for every
/// directed edge in this graph
/// :rtype: PyGraph
pub fn to_undirected(&self, py: Python) -> crate::graph::PyGraph {
let mut new_graph = StableUnGraph::<PyObject, PyObject>::default();
let mut node_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
for node_index in self.graph.node_indices() {
let node = self.graph[node_index].clone_ref(py);
let new_index = new_graph.add_node(node);
node_map.insert(node_index, new_index);
}
for edge in self.edge_references() {
let source = node_map.get(&edge.source()).unwrap();
let target = node_map.get(&edge.target()).unwrap();
let weight = edge.weight().clone_ref(py);
new_graph.add_edge(*source, *target, weight);
}
crate::graph::PyGraph {
graph: new_graph,
node_removed: false,
}
}
}

#[pyproto]
Expand Down
55 changes: 55 additions & 0 deletions tests/test_to_undirected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import retworkx


class TestToUndirected(unittest.TestCase):

def test_to_undirected_empty_graph(self):
digraph = retworkx.PyDiGraph()
graph = digraph.to_undirected()
self.assertEqual(0, len(graph))

def test_single_direction_graph(self):
digraph = retworkx.generators.directed_path_graph(5)
graph = digraph.to_undirected()
self.assertEqual(digraph.weighted_edge_list(),
graph.weighted_edge_list())

def test_bidirectional_graph(self):
digraph = retworkx.generators.directed_path_graph(5)
for i in range(0, 4):
digraph.add_edge(i + 1, i, None)
graph = digraph.to_undirected()
self.assertEqual(digraph.weighted_edge_list(),
graph.weighted_edge_list())

def test_shared_ref(self):
digraph = retworkx.PyDiGraph()
node_weight = {'a': 1}
node_a = digraph.add_node(node_weight)
edge_weight = {'a': 1}
digraph.add_child(node_a, 'b', edge_weight)
graph = digraph.to_undirected()
self.assertEqual(digraph[node_a], {'a': 1})
self.assertEqual(graph[node_a], {'a': 1})
node_weight['b'] = 2
self.assertEqual(digraph[node_a], {'a': 1, 'b': 2})
self.assertEqual(graph[node_a], {'a': 1, 'b': 2})
self.assertEqual(digraph.get_edge_data(0, 1), {'a': 1})
self.assertEqual(graph.get_edge_data(0, 1), {'a': 1})
edge_weight['b'] = 2
self.assertEqual(digraph.get_edge_data(0, 1), {'a': 1, 'b': 2})
self.assertEqual(graph.get_edge_data(0, 1), {'a': 1, 'b': 2})

0 comments on commit bbf0f0c

Please sign in to comment.