From 7c1943a4f783af5143f1da466fda84f6c0e42062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Thu, 8 Feb 2024 16:13:29 +0100 Subject: [PATCH] fix(lint): typo and fmt --- python/differt/_core/rt/graph.pyi | 2 +- tests/rt/test_graph.py | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/python/differt/_core/rt/graph.pyi b/python/differt/_core/rt/graph.pyi index 221c1b1b..9aae9058 100644 --- a/python/differt/_core/rt/graph.pyi +++ b/python/differt/_core/rt/graph.pyi @@ -13,7 +13,7 @@ class DiGraph: ) -> DiGraph: ... @classmethod def from_complete_graph(cls, graph: CompleteGraph) -> DiGraph: ... - def insert_from_and_to_nodes(self, direct_path: bool = True) -> DiGraph: ... + def insert_from_and_to_nodes(self, direct_path: bool = True) -> tuple[int, int]: ... def all_paths( self, from_: int, to: int, depth: int, include_from_and_to: bool = True ) -> AllPathsFromDiGraphIter: ... diff --git a/tests/rt/test_graph.py b/tests/rt/test_graph.py index ee44f9e7..a8ab0511 100644 --- a/tests/rt/test_graph.py +++ b/tests/rt/test_graph.py @@ -1,16 +1,12 @@ -from typing import Any - -import chex -import jax.numpy as jnp import pytest -from jaxtyping import Array from differt.rt.graph import ( - AllPathsFromDiGraphIter, DiGraph, CompleteGraph, + CompleteGraph, + DiGraph, ) -class TestDiGraph: +class TestDiGraph: def test_insert_from_and_to_nodes(self) -> None: graph = DiGraph.from_complete_graph(CompleteGraph(5)) from_, to = graph.insert_from_and_to_nodes() @@ -28,7 +24,7 @@ def test_insert_from_and_to_nodes(self) -> None: from_, to = graph.insert_from_and_to_nodes(False) assert from_ == 13 assert to == 14 - + def test_from_graph(self) -> None: graph = CompleteGraph(10) assert isinstance(graph, CompleteGraph) @@ -42,7 +38,6 @@ def test_all_paths_positional_only_parameters(self) -> None: _ = graph.all_paths(from_=0, to=1, depth=0) assert "unexpected keyword argument" in str(exc) - @pytest.mark.parametrize( "num_nodes,depth", [ @@ -58,7 +53,6 @@ def test_all_paths_count_from_complete_graph( from_, to = graph.insert_from_and_to_nodes() paths = graph.all_paths(from_, to, depth + 2, include_from_and_to=False) num_paths = sum(1 for _ in paths) - assert num_paths == num_nodes * (num_nodes - 1) ** (depth - 1) + assert num_paths == num_nodes * (num_nodes - 1) ** (depth - 1) array = graph.all_paths_array(from_, to, depth + 2, include_from_and_to=False) assert array.shape == (num_paths, depth) -