Skip to content

Commit

Permalink
fix(lint): typo and fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Feb 8, 2024
1 parent dd9025b commit 7c1943a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/differt/_core/rt/graph.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class DiGraph:
) -> DiGraph: ...
@classmethod
def from_complete_graph(cls, graph: CompleteGraph) -> DiGraph: ...
def insert_from_and_to_nodes(self, direct_path: bool = True) -> DiGraph: ...
def insert_from_and_to_nodes(self, direct_path: bool = True) -> tuple[int, int]: ...
def all_paths(
self, from_: int, to: int, depth: int, include_from_and_to: bool = True
) -> AllPathsFromDiGraphIter: ...
Expand Down
16 changes: 5 additions & 11 deletions tests/rt/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from typing import Any

import chex
import jax.numpy as jnp
import pytest
from jaxtyping import Array

from differt.rt.graph import (
AllPathsFromDiGraphIter, DiGraph, CompleteGraph,
CompleteGraph,
DiGraph,
)

class TestDiGraph:

class TestDiGraph:
def test_insert_from_and_to_nodes(self) -> None:
graph = DiGraph.from_complete_graph(CompleteGraph(5))
from_, to = graph.insert_from_and_to_nodes()
Expand All @@ -28,7 +24,7 @@ def test_insert_from_and_to_nodes(self) -> None:
from_, to = graph.insert_from_and_to_nodes(False)
assert from_ == 13
assert to == 14

def test_from_graph(self) -> None:
graph = CompleteGraph(10)
assert isinstance(graph, CompleteGraph)
Expand All @@ -42,7 +38,6 @@ def test_all_paths_positional_only_parameters(self) -> None:
_ = graph.all_paths(from_=0, to=1, depth=0)
assert "unexpected keyword argument" in str(exc)


@pytest.mark.parametrize(
"num_nodes,depth",
[
Expand All @@ -58,7 +53,6 @@ def test_all_paths_count_from_complete_graph(
from_, to = graph.insert_from_and_to_nodes()
paths = graph.all_paths(from_, to, depth + 2, include_from_and_to=False)
num_paths = sum(1 for _ in paths)
assert num_paths == num_nodes * (num_nodes - 1) ** (depth - 1)
assert num_paths == num_nodes * (num_nodes - 1) ** (depth - 1)
array = graph.all_paths_array(from_, to, depth + 2, include_from_and_to=False)
assert array.shape == (num_paths, depth)

0 comments on commit 7c1943a

Please sign in to comment.