Skip to content

Commit

Permalink
chore(test): add some basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Feb 8, 2024
1 parent 29cd0d5 commit dd9025b
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
18 changes: 17 additions & 1 deletion src/rt/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use numpy::{IntoPyArray, PyArray2};
use pyo3::prelude::*;

use super::graph::{complete::CompleteGraph, directed::DiGraph, PathsIterator};
use super::graph::{
complete::CompleteGraph,
directed::{AllPathsFromDiGraphIter, DiGraph},
PathsIterator,
};

/// Generate an array of all path candidates (assuming fully connected
/// primitives).
Expand All @@ -18,9 +22,21 @@ pub fn generate_all_path_candidates(
array.reversed_axes().into_pyarray(py)
}

/// Iterator variant of eponym function.
#[pyfunction]
pub fn generate_all_path_candidates_iter(
num_primitives: usize,
order: usize,
) -> AllPathsFromDiGraphIter {
let mut graph: DiGraph = CompleteGraph::new(num_primitives).into();
let (from, to) = graph.insert_from_and_to_nodes(true);
graph.all_paths(from, to, order + 2, false)
}

pub(crate) fn create_module(py: Python<'_>) -> PyResult<&PyModule> {
let m = pyo3::prelude::PyModule::new(py, "utils")?;
m.add_function(wrap_pyfunction!(generate_all_path_candidates, m)?)?;
m.add_function(wrap_pyfunction!(generate_all_path_candidates_iter, m)?)?;

Ok(m)
}
64 changes: 64 additions & 0 deletions tests/rt/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any

import chex
import jax.numpy as jnp
import pytest
from jaxtyping import Array

from differt.rt.graph import (
AllPathsFromDiGraphIter, DiGraph, CompleteGraph,
)

class TestDiGraph:

def test_insert_from_and_to_nodes(self) -> None:
graph = DiGraph.from_complete_graph(CompleteGraph(5))
from_, to = graph.insert_from_and_to_nodes()
assert from_ == 5
assert to == 6
from_, to = graph.insert_from_and_to_nodes(direct_path=True)
assert from_ == 7
assert to == 8
from_, to = graph.insert_from_and_to_nodes(direct_path=False)
assert from_ == 9
assert to == 10
from_, to = graph.insert_from_and_to_nodes(True)
assert from_ == 11
assert to == 12
from_, to = graph.insert_from_and_to_nodes(False)
assert from_ == 13
assert to == 14

def test_from_graph(self) -> None:
graph = CompleteGraph(10)
assert isinstance(graph, CompleteGraph)
graph = DiGraph.from_complete_graph(graph)
assert isinstance(graph, DiGraph)

def test_all_paths_positional_only_parameters(self) -> None:
graph = DiGraph.from_complete_graph(CompleteGraph(5))

with pytest.raises(TypeError) as exc:
_ = graph.all_paths(from_=0, to=1, depth=0)
assert "unexpected keyword argument" in str(exc)


@pytest.mark.parametrize(
"num_nodes,depth",
[
(10, 1),
(50, 2),
(10, 3),
],
)
def test_all_paths_count_from_complete_graph(
self, num_nodes: int, depth: int
) -> None:
graph = DiGraph.from_complete_graph(CompleteGraph(num_nodes))
from_, to = graph.insert_from_and_to_nodes()
paths = graph.all_paths(from_, to, depth + 2, include_from_and_to=False)
num_paths = sum(1 for _ in paths)
assert num_paths == num_nodes * (num_nodes - 1) ** (depth - 1)
array = graph.all_paths_array(from_, to, depth + 2, include_from_and_to=False)
assert array.shape == (num_paths, depth)

1 change: 0 additions & 1 deletion tests/rt/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def test_generate_all_path_candidates(
chex.assert_trees_all_equal(got, expected)


@pytest.mark.xfail(reason="TODO")
@pytest.mark.parametrize(
"num_primitives,order",
[
Expand Down

0 comments on commit dd9025b

Please sign in to comment.