Skip to content

Commit

Permalink
Add Dijkstra shortest path functions (#162)
Browse files Browse the repository at this point in the history
* Add to_undirected method for PyDiGraph

This commit adds a new method to the PyDiGraph class, to_undirected(),
which will generate an undirected PyGraph object from the PyDiGraph
object.

Fixes #153

* Fix lint

* Add Dijkstra shortest path functions

This commit adds 2 new functions, digraph_dijkstra_shortest_paths() and
graph_dijkstra_shortest_path(), which is a function to get the shortest
path from a node in a graph. It leverages the same dijkstra's algorithm
module which has been modified to get a path in addition to the path
length.

Depends on #161
Fixes #151

* Fix lint

* Fix duplicate weight_callable functions from rebase

This commit fixes an issue with duplicate weight_callable functions that
happened because one was added in this PR's branch and another was
added in a different PR. The functions were mostly identical so this
just consolidates the 2.

* Apply suggestions from code review

Co-authored-by: Lauren Capelluto <laurencapelluto@gmail.com>

* Add docs for paths parameter in dijkstra::dijkstra

* Use setUp() to build common test graphs

* Move path HashMap initialization into dijkstra::dijkstra()

Co-authored-by: Lauren Capelluto <laurencapelluto@gmail.com>
  • Loading branch information
mtreinish and lcapelluto authored Nov 9, 2020
1 parent b790a86 commit 0ef76fe
Show file tree
Hide file tree
Showing 4 changed files with 368 additions and 95 deletions.
2 changes: 2 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ Algorithm Functions
retworkx.digraph_all_simple_paths
retworkx.graph_astar_shortest_path
retworkx.digraph_astar_shortest_path
retworkx.graph_dijkstra_shortest_paths
retworkx.digraph_dijkstra_shortest_paths
retworkx.graph_dijkstra_shortest_path_lengths
retworkx.digraph_dijkstra_shortest_path_lengths
retworkx.graph_k_shortest_path_lengths
Expand Down
23 changes: 21 additions & 2 deletions src/dijkstra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
// License for the specific language governing permissions and limitations
// under the License.

// This module is copied and forked from the upstream petgraph repository,
// specifically:
// This module was originally copied and forked from the upstream petgraph
// repository, specifically:
// https://github.com/petgraph/petgraph/blob/0.5.1/src/dijkstra.rs
// this was necessary to modify the error handling to allow python callables
// to be use for the input functions for edge_cost and return any exceptions
Expand Down Expand Up @@ -42,6 +42,11 @@ use crate::astar::MinScored;
/// If `goal` is not `None`, then the algorithm terminates once the `goal` node's
/// cost is calculated.
///
/// If `path` is not `None`, then the algorithm will mutate the input
/// hashbrown::HashMap to insert an entry where the index is the dest node index
/// the value is a Vec of node indices of the path starting with `start` and
/// ending at the index.
///
/// Returns a `HashMap` that maps `NodeId` to path cost.
/// # Example
/// ```rust
Expand Down Expand Up @@ -97,6 +102,7 @@ pub fn dijkstra<G, F, K>(
start: G::NodeId,
goal: Option<G::NodeId>,
mut edge_cost: F,
mut path: Option<&mut HashMap<G::NodeId, Vec<G::NodeId>>>,
) -> PyResult<HashMap<G::NodeId, K>>
where
G: IntoEdges + Visitable,
Expand All @@ -110,6 +116,9 @@ where
let zero_score = K::default();
scores.insert(start, zero_score);
visit_next.push(MinScored(zero_score, start));
if path.is_some() {
path.as_mut().unwrap().insert(start, vec![start]);
}
while let Some(MinScored(node_score, node)) = visit_next.pop() {
if visited.is_visited(&node) {
continue;
Expand All @@ -134,6 +143,16 @@ where
Vacant(ent) => {
ent.insert(next_score);
visit_next.push(MinScored(next_score, next));
if path.is_some() {
let mut node_path =
path.as_mut().unwrap().get(&node).unwrap().clone();
path.as_mut().unwrap().entry(next).or_insert({
let mut new_vec: Vec<G::NodeId> = Vec::new();
new_vec.append(&mut node_path);
new_vec.push(next);
new_vec
});
}
}
}
}
Expand Down
202 changes: 177 additions & 25 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ fn graph_floyd_warshall_numpy(
// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
mat[[j, i]] = mat[[j, i]].min(edge_weight);
}
Expand Down Expand Up @@ -893,7 +893,7 @@ fn digraph_floyd_warshall_numpy(
// Build adjacency matrix
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
mat[[i, j]] = mat[[i, j]].min(edge_weight);
if as_undirected {
mat[[j, i]] = mat[[j, i]].min(edge_weight);
Expand Down Expand Up @@ -1167,21 +1167,6 @@ pub fn graph_distance_matrix(
Ok(matrix.into_pyarray(py).into())
}

fn weight_callable(
py: Python,
weight_fn: &Option<PyObject>,
weight: PyObject,
default: f64,
) -> PyResult<f64> {
match weight_fn {
Some(weight_fn) => {
let res = weight_fn.call1(py, (weight,))?;
res.extract(py)
}
None => Ok(default),
}
}

/// Return the adjacency matrix for a PyDiGraph object
///
/// In the case where there are multiple edges between nodes the value in the
Expand Down Expand Up @@ -1220,7 +1205,7 @@ fn digraph_adjacency_matrix(
let mut matrix = Array2::<f64>::zeros((n, n));
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
matrix[[i, j]] += edge_weight;
}
Ok(matrix.into_pyarray(py).into())
Expand Down Expand Up @@ -1263,7 +1248,7 @@ fn graph_adjacency_matrix(
let mut matrix = Array2::<f64>::zeros((n, n));
for (i, j, weight) in get_edge_iter_with_weights(graph) {
let edge_weight =
weight_callable(py, &weight_fn, weight, default_weight)?;
weight_callable(py, &weight_fn, &weight, default_weight)?;
matrix[[i, j]] += edge_weight;
matrix[[j, i]] += edge_weight;
}
Expand Down Expand Up @@ -1384,6 +1369,163 @@ fn digraph_all_simple_paths(
Ok(result)
}

fn weight_callable(
py: Python,
weight_fn: &Option<PyObject>,
weight: &PyObject,
default: f64,
) -> PyResult<f64> {
match weight_fn {
Some(weight_fn) => {
let res = weight_fn.call1(py, (weight,))?;
res.extract(py)
}
None => Ok(default),
}
}

/// Find the shortest path from a node
///
/// This function will generate the shortest path from a source node using
/// Dijkstra's algorithm.
///
/// :param PyGraph graph:
/// :param int source: The node index to find paths from
/// :param int target: An optional target to find a path to
/// :param weight_fn: An optional weight function for an edge. It will accept
/// a single argument, the edge's weight object and will return a float which
/// will be used to represent the weight/cost of the edge
/// :param float default_weight: If ``weight_fn`` isn't specified this optional
/// float value will be used for the weight/cost of each edge.
/// :param bool as_undirected: If set to true the graph will be treated as
/// undirected for finding the shortest path.
///
/// :return: Dictionary of paths. The keys are destination node indices and
/// the dict values are lists of node indices making the path.
/// :rtype: dict
#[pyfunction(default_weight = "1.0", as_undirected = "false")]
#[text_signature = "(graph, source, /, target=None weight_fn=None, default_weight=1.0)"]
pub fn graph_dijkstra_shortest_paths(
py: Python,
graph: &graph::PyGraph,
source: usize,
target: Option<usize>,
weight_fn: Option<PyObject>,
default_weight: f64,
) -> PyResult<PyObject> {
let start = NodeIndex::new(source);
let goal_index: Option<NodeIndex> = match target {
Some(node) => Some(NodeIndex::new(node)),
None => None,
};
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
dijkstra::dijkstra(
graph,
start,
goal_index,
|e| weight_callable(py, &weight_fn, e.weight(), default_weight),
Some(&mut paths),
)?;

let out_dict = PyDict::new(py);
for (index, value) in paths {
let int_index = index.index();
if int_index == source {
continue;
}
if (target.is_some() && target.unwrap() == int_index)
|| target.is_none()
{
out_dict.set_item(
int_index,
value
.iter()
.map(|index| index.index())
.collect::<Vec<usize>>(),
)?;
}
}
Ok(out_dict.into())
}

/// Find the shortest path from a node
///
/// This function will generate the shortest path from a source node using
/// Dijkstra's algorithm.
///
/// :param PyDiGraph graph:
/// :param int source: The node index to find paths from
/// :param int target: An optional target path to find the path
/// :param weight_fn: An optional weight function for an edge. It will accept
/// a single argument, the edge's weight object and will return a float which
/// will be used to represent the weight/cost of the edge
/// :param float default_weight: If ``weight_fn`` isn't specified this optional
/// float value will be used for the weight/cost of each edge.
/// :param bool as_undirected: If set to true the graph will be treated as
/// undirected for finding the shortest path.
///
/// :return: Dictionary of paths. The keys are destination node indices and
/// the dict values are lists of node indices making the path.
/// :rtype: dict
#[pyfunction(default_weight = "1.0", as_undirected = "false")]
#[text_signature = "(graph, source, /, target=None weight_fn=None, default_weight=1.0, as_undirected=False)"]
pub fn digraph_dijkstra_shortest_paths(
py: Python,
graph: &digraph::PyDiGraph,
source: usize,
target: Option<usize>,
weight_fn: Option<PyObject>,
default_weight: f64,
as_undirected: bool,
) -> PyResult<PyObject> {
let start = NodeIndex::new(source);
let goal_index: Option<NodeIndex> = match target {
Some(node) => Some(NodeIndex::new(node)),
None => None,
};
let mut paths: HashMap<NodeIndex, Vec<NodeIndex>> = HashMap::new();
if as_undirected {
dijkstra::dijkstra(
// TODO: Use petgraph undirected adapter after
// https://github.com/petgraph/petgraph/pull/318 is available in
// a petgraph release.
&graph.to_undirected(py),
start,
goal_index,
|e| weight_callable(py, &weight_fn, e.weight(), default_weight),
Some(&mut paths),
)?;
} else {
dijkstra::dijkstra(
graph,
start,
goal_index,
|e| weight_callable(py, &weight_fn, e.weight(), default_weight),
Some(&mut paths),
)?;
}

let out_dict = PyDict::new(py);
for (index, value) in paths {
let int_index = index.index();
if int_index == source {
continue;
}
if (target.is_some() && target.unwrap() == int_index)
|| target.is_none()
{
out_dict.set_item(
int_index,
value
.iter()
.map(|index| index.index())
.collect::<Vec<usize>>(),
)?;
}
}
Ok(out_dict.into())
}

/// Compute the lengths of the shortest paths for a PyGraph object using
/// Dijkstra's algorithm
///
Expand Down Expand Up @@ -1423,9 +1565,13 @@ fn graph_dijkstra_shortest_path_lengths(
None => None,
};

let res = dijkstra::dijkstra(graph, start, goal_index, |e| {
edge_cost_callable(e.weight())
})?;
let res = dijkstra::dijkstra(
graph,
start,
goal_index,
|e| edge_cost_callable(e.weight()),
None,
)?;
let out_dict = PyDict::new(py);
for (index, value) in res {
let int_index = index.index();
Expand Down Expand Up @@ -1478,9 +1624,13 @@ fn digraph_dijkstra_shortest_path_lengths(
None => None,
};

let res = dijkstra::dijkstra(graph, start, goal_index, |e| {
edge_cost_callable(e.weight())
})?;
let res = dijkstra::dijkstra(
graph,
start,
goal_index,
|e| edge_cost_callable(e.weight()),
None,
)?;
let out_dict = PyDict::new(py);
for (index, value) in res {
let int_index = index.index();
Expand Down Expand Up @@ -2234,6 +2384,8 @@ fn retworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(graph_adjacency_matrix))?;
m.add_wrapped(wrap_pyfunction!(graph_all_simple_paths))?;
m.add_wrapped(wrap_pyfunction!(digraph_all_simple_paths))?;
m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_paths))?;
m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_shortest_paths))?;
m.add_wrapped(wrap_pyfunction!(graph_dijkstra_shortest_path_lengths))?;
m.add_wrapped(wrap_pyfunction!(digraph_dijkstra_shortest_path_lengths))?;
m.add_wrapped(wrap_pyfunction!(graph_astar_shortest_path))?;
Expand Down
Loading

0 comments on commit 0ef76fe

Please sign in to comment.