From e9bca2c867fc891659c72c0175036074ed9229cd Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Fri, 9 Aug 2024 22:00:36 +0200 Subject: [PATCH 1/3] improve abstract typing --- reflex/state.py | 7 +++++++ reflex/utils/types.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index b0c6646ce9b..f3aefe54f8d 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -55,6 +55,7 @@ from reflex.utils.exceptions import ImmutableStateError, LockExpiredError from reflex.utils.exec import is_testing_env from reflex.utils.serializers import SerializedType, serialize, serializer +from reflex.utils.types import override from reflex.vars import BaseVar, ComputedVar, Var, computed_var if TYPE_CHECKING: @@ -2368,6 +2369,7 @@ class Config: "_states_locks": {"exclude": True}, } + @override async def get_state(self, token: str) -> BaseState: """Get the state for a token. @@ -2383,6 +2385,7 @@ async def get_state(self, token: str) -> BaseState: self.states[token] = self.state(_reflex_internal_init=True) return self.states[token] + @override async def set_state(self, token: str, state: BaseState): """Set the state for a token. @@ -2392,6 +2395,7 @@ async def set_state(self, token: str, state: BaseState): """ pass + @override @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. @@ -2558,6 +2562,7 @@ async def _populate_substates( for substate_name, substate_task in tasks.items(): state.substates[substate_name] = await substate_task + @override async def get_state( self, token: str, @@ -2657,6 +2662,7 @@ def _warn_if_too_large( ) self._warned_about_state_size.add(state_full_name) + @override async def set_state( self, token: str, @@ -2717,6 +2723,7 @@ async def set_state( for t in tasks: await t + @override @contextlib.asynccontextmanager async def modify_state(self, token: str) -> AsyncIterator[BaseState]: """Modify the state for a token while holding exclusive lock. diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 3bb5eae355b..32438b9e684 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -49,7 +49,7 @@ from reflex.utils import console if sys.version_info >= (3, 12): - from typing import override + from typing import override as override else: def override(func: Callable) -> Callable: From 21353794f9aff86c30594b82f347cee0ac6602f0 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Fri, 9 Aug 2024 22:06:58 +0200 Subject: [PATCH 2/3] streamline get_root_state --- reflex/state.py | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/reflex/state.py b/reflex/state.py index f3aefe54f8d..6506ba3fef1 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1233,6 +1233,17 @@ def _get_parent_states(self) -> list[tuple[str, BaseState]]: parent_states_with_name.append((parent_state.get_full_name(), parent_state)) return parent_states_with_name + def _get_root_state(self) -> BaseState: + """Get the root state of the state tree. + + Returns: + The root state of the state tree. + """ + parent_state = self + while parent_state.parent_state is not None: + parent_state = parent_state.parent_state + return parent_state + async def _populate_parent_states(self, target_state_cls: Type[BaseState]): """Populate substates in the tree between the target_state_cls and common ancestor of this state. @@ -1260,7 +1271,7 @@ async def _populate_parent_states(self, target_state_cls: Type[BaseState]): # Fetch all missing parent states and link them up to the common ancestor. parent_states_tuple = self._get_parent_states() - root_state = parent_states_tuple[-1][1] + root_state = self._get_root_state() parent_states_by_name = dict(parent_states_tuple) parent_state = parent_states_by_name[common_ancestor_name] for parent_state_name in missing_parent_states: @@ -1292,10 +1303,7 @@ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState: Returns: The instance of state_cls associated with this state's client_token. """ - if self.parent_state is None: - root_state = self - else: - root_state = self._get_parent_states()[-1][1] + root_state = self._get_root_state() return root_state.get_substate(state_cls.get_full_name().split(".")) async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState: @@ -1446,9 +1454,7 @@ def _as_state_update( The valid StateUpdate containing the events and final flag. """ # get the delta from the root of the state tree - state = self - while state.parent_state is not None: - state = state.parent_state + state = self._get_root_state() token = self.router.session.client_token @@ -2487,19 +2493,6 @@ class StateManagerRedis(StateManager): # Only warn about each state class size once. _warned_about_state_size: ClassVar[Set[str]] = set() - def _get_root_state(self, state: BaseState) -> BaseState: - """Chase parent_state pointers to find an instance of the top-level state. - - Args: - state: The state to start from. - - Returns: - An instance of the top-level state (self.state). - """ - while type(state) != self.state and state.parent_state is not None: - state = state.parent_state - return state - async def _get_parent_state(self, token: str) -> BaseState | None: """Get the parent state for the state requested in the token. @@ -2614,7 +2607,7 @@ async def get_state( # To retain compatibility with previous implementation, by default, we return # the top-level state by chasing `parent_state` pointers up the tree. if top_level: - return self._get_root_state(state) + return state._get_root_state() return state # TODO: dedupe the following logic with the above block @@ -2636,7 +2629,7 @@ async def get_state( # To retain compatibility with previous implementation, by default, we return # the top-level state by chasing `parent_state` pointers up the tree. if top_level: - return self._get_root_state(state) + return state._get_root_state() return state def _warn_if_too_large( From d4526872144b8a44d437985accdf24c1f6df5ec8 Mon Sep 17 00:00:00 2001 From: Benedikt Bartscher Date: Mon, 12 Aug 2024 16:40:26 +0200 Subject: [PATCH 3/3] revert this line, avoid unneeded computation --- reflex/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index 6506ba3fef1..228530fdace 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1271,7 +1271,7 @@ async def _populate_parent_states(self, target_state_cls: Type[BaseState]): # Fetch all missing parent states and link them up to the common ancestor. parent_states_tuple = self._get_parent_states() - root_state = self._get_root_state() + root_state = parent_states_tuple[-1][1] parent_states_by_name = dict(parent_states_tuple) parent_state = parent_states_by_name[common_ancestor_name] for parent_state_name in missing_parent_states: