From 417e1e0d213920d30d356959f533ad3384167652 Mon Sep 17 00:00:00 2001 From: Ivan Carvalho Date: Sun, 5 Jan 2025 18:01:25 -0500 Subject: [PATCH] Fix issues with rustworkx.visit annotations --- .../fix-visit-stub-eafc015adf5cada0.yaml | 7 ++++ rustworkx/__init__.pyi | 4 +-- rustworkx/visit.py | 10 ++++-- rustworkx/visit.pyi | 34 +++++++++---------- 4 files changed, 33 insertions(+), 22 deletions(-) create mode 100644 releasenotes/notes/fix-visit-stub-eafc015adf5cada0.yaml diff --git a/releasenotes/notes/fix-visit-stub-eafc015adf5cada0.yaml b/releasenotes/notes/fix-visit-stub-eafc015adf5cada0.yaml new file mode 100644 index 0000000000..3b794bb8bb --- /dev/null +++ b/releasenotes/notes/fix-visit-stub-eafc015adf5cada0.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed a bug in the discoverability of the type hints for the `rustworkx.visit` module. + Classes declared in the module are also now properly annotated as accepting generic types. + Refer to `issue 1352 `__ for + more information. \ No newline at end of file diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index 5952e177e1..495adf1447 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -19,8 +19,8 @@ from collections.abc import Iterator, Sequence # rustworkx module we need to explicitly re-export every inner function from # rustworkx.rustworkx (the root rust module) in the form: # `from .rustworkx import foo as foo` so that mypy will treat `rustworkx.foo` -# as a valid path -import rustworkx.visit as visit +# as a valid path. +from . import visit as visit from .rustworkx import DAGHasCycle as DAGHasCycle from .rustworkx import DAGWouldCycle as DAGWouldCycle diff --git a/rustworkx/visit.py b/rustworkx/visit.py index 26e99a7c3b..47eed2b06b 100644 --- a/rustworkx/visit.py +++ b/rustworkx/visit.py @@ -6,6 +6,10 @@ # copyright notice, and modified files need to carry a notice indicating # that they have been altered from the originals. +from typing import TypeVar, Generic + +_T = TypeVar("_T") + class StopSearch(Exception): """Stop graph traversal""" @@ -19,7 +23,7 @@ class PruneSearch(Exception): pass -class BFSVisitor: +class BFSVisitor(Generic[_T]): """A visitor object that is invoked at the event-points inside the :func:`~rustworkx.bfs_search` algorithm. By default, it performs no action, and should be used as a base class in order to be useful. @@ -68,7 +72,7 @@ def black_target_edge(self, e): return -class DFSVisitor: +class DFSVisitor(Generic[_T]): """A visitor object that is invoked at the event-points inside the :func:`~rustworkx.dfs_search` algorithm. By default, it performs no action, and should be used as a base class in order to be useful. @@ -119,7 +123,7 @@ def forward_or_cross_edge(self, e): return -class DijkstraVisitor: +class DijkstraVisitor(Generic[_T]): """A visitor object that is invoked at the event-points inside the :func:`~rustworkx.dijkstra_search` algorithm. By default, it performs no action, and should be used as a base class in order to be useful. diff --git a/rustworkx/visit.pyi b/rustworkx/visit.pyi index 0171307556..ca44945258 100644 --- a/rustworkx/visit.pyi +++ b/rustworkx/visit.pyi @@ -9,7 +9,7 @@ # This file contains only type annotations for PyO3 functions and classes # For implementation details, see visit.py -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar class StopSearch(Exception): ... class PruneSearch(Exception): ... @@ -17,23 +17,23 @@ class PruneSearch(Exception): ... _T = TypeVar("_T") class BFSVisitor(Generic[_T]): - def discover_vertex(self, v: int): ... - def finish_vertex(self, v: int): ... - def tree_edge(self, e: tuple[int, int, _T]): ... - def non_tree_edge(self, e: tuple[int, int, _T]): ... - def gray_target_edge(self, e: tuple[int, int, _T]): ... - def black_target_edge(self, e: tuple[int, int, _T]): ... + def discover_vertex(self, v: int) -> Any: ... + def finish_vertex(self, v: int) -> Any: ... + def tree_edge(self, e: tuple[int, int, _T]) -> Any: ... + def non_tree_edge(self, e: tuple[int, int, _T]) -> Any: ... + def gray_target_edge(self, e: tuple[int, int, _T]) -> Any: ... + def black_target_edge(self, e: tuple[int, int, _T]) -> Any: ... class DFSVisitor(Generic[_T]): - def discover_vertex(self, v: int, t: int): ... - def finish_vertex(self, v: int, t: int): ... - def tree_edge(self, e: tuple[int, int, _T]): ... - def back_edge(self, e: tuple[int, int, _T]): ... - def forward_or_cross_edge(self, e: tuple[int, int, _T]): ... + def discover_vertex(self, v: int, t: int) -> Any: ... + def finish_vertex(self, v: int, t: int) -> Any: ... + def tree_edge(self, e: tuple[int, int, _T]) -> Any: ... + def back_edge(self, e: tuple[int, int, _T]) -> Any: ... + def forward_or_cross_edge(self, e: tuple[int, int, _T]) -> Any: ... class DijkstraVisitor(Generic[_T]): - def discover_vertex(self, v: int, score: float): ... - def finish_vertex(self, v: int): ... - def examine_edge(self, edge: tuple[int, int, _T]): ... - def edge_relaxed(self, edge: tuple[int, int, _T]): ... - def edge_not_relaxed(self, edge: tuple[int, int, _T]): ... + def discover_vertex(self, v: int, score: float) -> Any: ... + def finish_vertex(self, v: int) -> Any: ... + def examine_edge(self, edge: tuple[int, int, _T]) -> Any: ... + def edge_relaxed(self, edge: tuple[int, int, _T]) -> Any: ... + def edge_not_relaxed(self, edge: tuple[int, int, _T]) -> Any: ...