From dd9025b5cb0e82adfc36d0dbfc766748efe0737f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Thu, 8 Feb 2024 16:08:35 +0100 Subject: [PATCH] chore(test): add some basic tests --- src/rt/utils.rs | 18 +++++++++++- tests/rt/test_graph.py | 64 ++++++++++++++++++++++++++++++++++++++++++ tests/rt/test_utils.py | 1 - 3 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tests/rt/test_graph.py diff --git a/src/rt/utils.rs b/src/rt/utils.rs index 9ee6ce60..d12dc45c 100644 --- a/src/rt/utils.rs +++ b/src/rt/utils.rs @@ -1,7 +1,11 @@ use numpy::{IntoPyArray, PyArray2}; use pyo3::prelude::*; -use super::graph::{complete::CompleteGraph, directed::DiGraph, PathsIterator}; +use super::graph::{ + complete::CompleteGraph, + directed::{AllPathsFromDiGraphIter, DiGraph}, + PathsIterator, +}; /// Generate an array of all path candidates (assuming fully connected /// primitives). @@ -18,9 +22,21 @@ pub fn generate_all_path_candidates( array.reversed_axes().into_pyarray(py) } +/// Iterator variant of eponym function. +#[pyfunction] +pub fn generate_all_path_candidates_iter( + num_primitives: usize, + order: usize, +) -> AllPathsFromDiGraphIter { + let mut graph: DiGraph = CompleteGraph::new(num_primitives).into(); + let (from, to) = graph.insert_from_and_to_nodes(true); + graph.all_paths(from, to, order + 2, false) +} + pub(crate) fn create_module(py: Python<'_>) -> PyResult<&PyModule> { let m = pyo3::prelude::PyModule::new(py, "utils")?; m.add_function(wrap_pyfunction!(generate_all_path_candidates, m)?)?; + m.add_function(wrap_pyfunction!(generate_all_path_candidates_iter, m)?)?; Ok(m) } diff --git a/tests/rt/test_graph.py b/tests/rt/test_graph.py new file mode 100644 index 00000000..ee44f9e7 --- /dev/null +++ b/tests/rt/test_graph.py @@ -0,0 +1,64 @@ +from typing import Any + +import chex +import jax.numpy as jnp +import pytest +from jaxtyping import Array + +from differt.rt.graph import ( + AllPathsFromDiGraphIter, DiGraph, CompleteGraph, +) + +class TestDiGraph: + + def test_insert_from_and_to_nodes(self) -> None: + graph = DiGraph.from_complete_graph(CompleteGraph(5)) + from_, to = graph.insert_from_and_to_nodes() + assert from_ == 5 + assert to == 6 + from_, to = graph.insert_from_and_to_nodes(direct_path=True) + assert from_ == 7 + assert to == 8 + from_, to = graph.insert_from_and_to_nodes(direct_path=False) + assert from_ == 9 + assert to == 10 + from_, to = graph.insert_from_and_to_nodes(True) + assert from_ == 11 + assert to == 12 + from_, to = graph.insert_from_and_to_nodes(False) + assert from_ == 13 + assert to == 14 + + def test_from_graph(self) -> None: + graph = CompleteGraph(10) + assert isinstance(graph, CompleteGraph) + graph = DiGraph.from_complete_graph(graph) + assert isinstance(graph, DiGraph) + + def test_all_paths_positional_only_parameters(self) -> None: + graph = DiGraph.from_complete_graph(CompleteGraph(5)) + + with pytest.raises(TypeError) as exc: + _ = graph.all_paths(from_=0, to=1, depth=0) + assert "unexpected keyword argument" in str(exc) + + + @pytest.mark.parametrize( + "num_nodes,depth", + [ + (10, 1), + (50, 2), + (10, 3), + ], + ) + def test_all_paths_count_from_complete_graph( + self, num_nodes: int, depth: int + ) -> None: + graph = DiGraph.from_complete_graph(CompleteGraph(num_nodes)) + from_, to = graph.insert_from_and_to_nodes() + paths = graph.all_paths(from_, to, depth + 2, include_from_and_to=False) + num_paths = sum(1 for _ in paths) + assert num_paths == num_nodes * (num_nodes - 1) ** (depth - 1) + array = graph.all_paths_array(from_, to, depth + 2, include_from_and_to=False) + assert array.shape == (num_paths, depth) + diff --git a/tests/rt/test_utils.py b/tests/rt/test_utils.py index dd05e901..66e501dc 100644 --- a/tests/rt/test_utils.py +++ b/tests/rt/test_utils.py @@ -56,7 +56,6 @@ def test_generate_all_path_candidates( chex.assert_trees_all_equal(got, expected) -@pytest.mark.xfail(reason="TODO") @pytest.mark.parametrize( "num_primitives,order", [