Skip to content

Commit

Permalink
Make weight_fn optional in adjacency matrix and fw numpy (#158)
Browse files Browse the repository at this point in the history
* Make weight_fn optional adjacency matrix fw numpy

This commit makes the weight_fn argument for graph_adjacency_matrix,
digraph_adjacency_matrix, graph_floyd_warshall_numpy, and
digraph_floyd_warshall_numpy optional. A new kwarg is added
default_weight (which defaults to 1.0) which can be used instead of
passing a callable. If weight_fn is not set the value of default_weight
will be used for all edges. Previously, a function returning a fixed
value would have to be used to accomplish this. In practice there was
not much overhead to just using something like 'lambda _: 1' as the
weight fn, but it was a bit of a clumsy interface.

* Add tests

* Fix rebase issue

* Add test coverage for floyd warshall too

* Fix text_signature for weight_fn
  • Loading branch information
mtreinish authored Nov 6, 2020
1 parent f35becf commit d8bd470
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 40 deletions.
90 changes: 50 additions & 40 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,25 +811,22 @@ where
/// path between two nodes then the corresponding matrix entry will be
/// ``np.inf``.
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"]
fn graph_floyd_warshall_numpy(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();
// Allocate empty matrix
let mut mat = Array2::<f64>::from_elem((n, n), std::f64::INFINITY);

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};

// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
mat[[j, i]] = mat[[j, i]].min(edge_weight);
}
Expand Down Expand Up @@ -879,27 +876,24 @@ fn graph_floyd_warshall_numpy(
/// path between two nodes then the corresponding matrix entry will be
/// ``np.inf``.
/// :rtype: numpy.ndarray
#[pyfunction(as_undirected = "false")]
#[text_signature = "(graph, weight_fn, /, as_undirected=False)"]
#[pyfunction(as_undirected = "false", default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None as_undirected=False, default_weight=1.0)"]
fn digraph_floyd_warshall_numpy(
py: Python,
graph: &digraph::PyDiGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
as_undirected: bool,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();

// Allocate empty matrix
let mut mat = Array2::<f64>::from_elem((n, n), std::f64::INFINITY);

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};

// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
if as_undirected {
mat[[j, i]] = mat[[j, i]].min(edge_weight);
Expand Down Expand Up @@ -1016,14 +1010,29 @@ fn layers(
Ok(PyList::new(py, output).into())
}

fn weight_callable(
py: Python,
weight_fn: &Option<PyObject>,
weight: PyObject,
default: f64,
) -> PyResult<f64> {
match weight_fn {
Some(weight_fn) => {
let res = weight_fn.call1(py, (weight,))?;
res.extract(py)
}
None => Ok(default),
}
}

/// Return the adjacency matrix for a PyDiGraph object
///
/// In the case where there are multiple edges between nodes the value in the
/// output matrix will be the sum of the edges' weights.
///
/// :param PyDiGraph graph: The DiGraph used to generate the adjacency matrix
/// from
/// :param weight_fn callable: A callable object (function, lambda, etc) which
/// :param callable weight_fn: A callable object (function, lambda, etc) which
/// will be passed the edge object and expected to return a ``float``. This
/// tells retworkx/rust how to extract a numerical weight as a ``float``
/// for edge object. Some simple examples are::
Expand All @@ -1034,26 +1043,27 @@ fn layers(
///
/// dag_adjacency_matrix(dag, weight_fn: lambda x: float(x))
///
/// to cast the edge object as a float as the weight.
/// to cast the edge object as a float as the weight. If this is not
/// specified a default value (either ``default_weight`` or 1) will be used
/// for all edges.
/// :param float default_weight: If ``weight_fn`` is not used this can be
/// optionally used to specify a default weight to use for all edges.
///
/// :return: The adjacency matrix for the input dag as a numpy array
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"]
fn digraph_adjacency_matrix(
py: Python,
graph: &digraph::PyDiGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();
let mut matrix = Array2::<f64>::zeros((n, n));

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
matrix[[i, j]] += edge_weight;
}
Ok(matrix.into_pyarray(py).into())
Expand All @@ -1076,30 +1086,30 @@ fn digraph_adjacency_matrix(
///
/// graph_adjacency_matrix(graph, weight_fn: lambda x: float(x))
///
/// to cast the edge object as a float as the weight.
/// to cast the edge object as a float as the weight. If this is not
/// specified a default value (either ``default_weight`` or 1) will be used
/// for all edges.
/// :param float default_weight: If ``weight_fn`` is not used this can be
/// optionally used to specify a default weight to use for all edges.
///
/// :return: The adjacency matrix for the input dag as a numpy array
/// :rtype: numpy.ndarray
#[pyfunction]
#[text_signature = "(graph, weight_fn, /)"]
#[pyfunction(default_weight = "1.0")]
#[text_signature = "(graph, /, weight_fn=None, default_weight=1.0)"]
fn graph_adjacency_matrix(
py: Python,
graph: &graph::PyGraph,
weight_fn: PyObject,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let n = graph.node_count();
let mut matrix = Array2::<f64>::zeros((n, n));

let weight_callable = |a: &PyObject| -> PyResult<f64> {
let res = weight_fn.call1(py, (a,))?;
res.extract(py)
};
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight = weight_callable(&weight)?;
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
matrix[[i, j]] += edge_weight;
matrix[[j, i]] += edge_weight;
}

Ok(matrix.into_pyarray(py).into())
}

Expand Down
56 changes: 56 additions & 0 deletions tests/test_adjacency_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,32 @@ def test_single_neighbor(self):
dtype=np.float64),
res))

def test_no_weight_fn(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
dag.add_child(node_a, 'b', {'a': 1})
dag.add_child(node_a, 'c', {'a': 2})
res = retworkx.digraph_adjacency_matrix(dag)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
dtype=np.float64),
res))

def test_default_weight(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
dag.add_child(node_a, 'b', {'a': 1})
dag.add_child(node_a, 'c', {'a': 2})
res = retworkx.digraph_adjacency_matrix(dag, default_weight=4)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 4.0, 4.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
dtype=np.float64),
res))

def test_float_cast_weight_func(self):
dag = retworkx.PyDAG()
node_a = dag.add_node('a')
Expand Down Expand Up @@ -88,6 +114,36 @@ def test_single_neighbor(self):
dtype=np.float64),
res))

def test_no_weight_fn(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, 'edge_a')
node_c = graph.add_node('c')
graph.add_edge(node_b, node_c, 'edge_b')
res = retworkx.graph_adjacency_matrix(graph)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]],
dtype=np.float64),
res))

def test_default_weight(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
node_b = graph.add_node('b')
graph.add_edge(node_a, node_b, 'edge_a')
node_c = graph.add_node('c')
graph.add_edge(node_b, node_c, 'edge_b')
res = retworkx.graph_adjacency_matrix(graph, default_weight=4)
self.assertIsInstance(res, np.ndarray)
self.assertTrue(np.array_equal(
np.array(
[[0.0, 4.0, 0.0], [4.0, 0.0, 4.0], [0.0, 4.0, 0.0]],
dtype=np.float64),
res))

def test_float_cast_weight_func(self):
graph = retworkx.PyGraph()
node_a = graph.add_node('a')
Expand Down
40 changes: 40 additions & 0 deletions tests/test_floyd_warshall.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,43 @@ def test_floyd_warshall_numpy_graph_cycle_with_removals(self):
dist = retworkx.graph_floyd_warshall_numpy(graph, lambda x: 1)
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 3)

def test_floyd_warshall_numpy_digraph_cycle_no_weight_fn(self):
graph = retworkx.PyDiGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.digraph_floyd_warshall_numpy(graph)
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 4)

def test_floyd_warshall_numpy_graph_cycle_no_weight_fn(self):
graph = retworkx.PyGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.graph_floyd_warshall_numpy(graph)
self.assertEqual(dist[0, 3], 3)
self.assertEqual(dist[0, 4], 3)

def test_floyd_warshall_numpy_digraph_cycle_default_weight(self):
graph = retworkx.PyDiGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.digraph_floyd_warshall_numpy(graph, default_weight=2)
self.assertEqual(dist[0, 3], 6)
self.assertEqual(dist[0, 4], 8)

def test_floyd_warshall_numpy_graph_cycle_default_weight(self):
graph = retworkx.PyGraph()
graph.add_nodes_from(list(range(8)))
graph.remove_node(0)
graph.add_edges_from_no_data(
[(1, 2), (1, 7), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)])
dist = retworkx.graph_floyd_warshall_numpy(graph, default_weight=2)
self.assertEqual(dist[0, 3], 6)
self.assertEqual(dist[0, 4], 6)

0 comments on commit d8bd470

Please sign in to comment.