Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor State cleanup #3768

Merged
merged 3 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from reflex.utils.exceptions import ImmutableStateError, LockExpiredError
from reflex.utils.exec import is_testing_env
from reflex.utils.serializers import SerializedType, serialize, serializer
from reflex.utils.types import override
from reflex.vars import BaseVar, ComputedVar, Var, computed_var

if TYPE_CHECKING:
Expand Down Expand Up @@ -1232,6 +1233,17 @@ def _get_parent_states(self) -> list[tuple[str, BaseState]]:
parent_states_with_name.append((parent_state.get_full_name(), parent_state))
return parent_states_with_name

def _get_root_state(self) -> BaseState:
"""Get the root state of the state tree.

Returns:
The root state of the state tree.
"""
parent_state = self
while parent_state.parent_state is not None:
parent_state = parent_state.parent_state
return parent_state

async def _populate_parent_states(self, target_state_cls: Type[BaseState]):
"""Populate substates in the tree between the target_state_cls and common ancestor of this state.

Expand Down Expand Up @@ -1291,10 +1303,7 @@ def _get_state_from_cache(self, state_cls: Type[BaseState]) -> BaseState:
Returns:
The instance of state_cls associated with this state's client_token.
"""
if self.parent_state is None:
root_state = self
else:
root_state = self._get_parent_states()[-1][1]
root_state = self._get_root_state()
return root_state.get_substate(state_cls.get_full_name().split("."))

async def _get_state_from_redis(self, state_cls: Type[BaseState]) -> BaseState:
Expand Down Expand Up @@ -1445,9 +1454,7 @@ def _as_state_update(
The valid StateUpdate containing the events and final flag.
"""
# get the delta from the root of the state tree
state = self
while state.parent_state is not None:
state = state.parent_state
state = self._get_root_state()

token = self.router.session.client_token

Expand Down Expand Up @@ -2368,6 +2375,7 @@ class Config:
"_states_locks": {"exclude": True},
}

@override
async def get_state(self, token: str) -> BaseState:
"""Get the state for a token.

Expand All @@ -2383,6 +2391,7 @@ async def get_state(self, token: str) -> BaseState:
self.states[token] = self.state(_reflex_internal_init=True)
return self.states[token]

@override
async def set_state(self, token: str, state: BaseState):
"""Set the state for a token.

Expand All @@ -2392,6 +2401,7 @@ async def set_state(self, token: str, state: BaseState):
"""
pass

@override
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Expand Down Expand Up @@ -2483,19 +2493,6 @@ class StateManagerRedis(StateManager):
# Only warn about each state class size once.
_warned_about_state_size: ClassVar[Set[str]] = set()

def _get_root_state(self, state: BaseState) -> BaseState:
"""Chase parent_state pointers to find an instance of the top-level state.

Args:
state: The state to start from.

Returns:
An instance of the top-level state (self.state).
"""
while type(state) != self.state and state.parent_state is not None:
state = state.parent_state
return state

async def _get_parent_state(self, token: str) -> BaseState | None:
"""Get the parent state for the state requested in the token.

Expand Down Expand Up @@ -2558,6 +2555,7 @@ async def _populate_substates(
for substate_name, substate_task in tasks.items():
state.substates[substate_name] = await substate_task

@override
async def get_state(
self,
token: str,
Expand Down Expand Up @@ -2609,7 +2607,7 @@ async def get_state(
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
return self._get_root_state(state)
return state._get_root_state()
return state

# TODO: dedupe the following logic with the above block
Expand All @@ -2631,7 +2629,7 @@ async def get_state(
# To retain compatibility with previous implementation, by default, we return
# the top-level state by chasing `parent_state` pointers up the tree.
if top_level:
return self._get_root_state(state)
return state._get_root_state()
return state

def _warn_if_too_large(
Expand All @@ -2657,6 +2655,7 @@ def _warn_if_too_large(
)
self._warned_about_state_size.add(state_full_name)

@override
async def set_state(
self,
token: str,
Expand Down Expand Up @@ -2717,6 +2716,7 @@ async def set_state(
for t in tasks:
await t

@override
@contextlib.asynccontextmanager
async def modify_state(self, token: str) -> AsyncIterator[BaseState]:
"""Modify the state for a token while holding exclusive lock.
Expand Down
2 changes: 1 addition & 1 deletion reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from reflex.utils import console

if sys.version_info >= (3, 12):
from typing import override
from typing import override as override
else:

def override(func: Callable) -> Callable:
Expand Down
Loading