Skip to content

Commit

Permalink
Removing change callbacks short circuiting
Browse files Browse the repository at this point in the history
The graph_shortcircuit test was flaky in that the
quadrupled.for_each_cloned callback could happen after quadrupled was
set to 16, and the short circuiting of the original call chain prevented
the changes from propagating because when the callback executed it
observed 16 not the originally expected 4. The original callbacks queued
up by setting quadrupled to 16 would fire, but PartialEq on the set
operation would prevent the change callbacks from firing, causuing the
chain to be out of sync.

Since the PartialEq requirements on various APIs are designed to prevent
infinite cycle updates, it seems silly to also try to short circuit in
the graph as well. And, I couldn't come up with a reasonable approach to
solving this problem that I would expect end-users to also apply.

So, this may lead to situations where infinite callback loops can happen
if logic does not prevent it. However, that seems a lot easier to reason
about as opposed to "why did these updates never fire?"
  • Loading branch information
ecton committed Jan 20, 2025
1 parent 737b65c commit 56536be
Showing 1 changed file with 40 additions and 223 deletions.
263 changes: 40 additions & 223 deletions src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2306,10 +2306,6 @@ impl ChangeCallbacks {
self.data.lock.sync.notify_all();
count
}

fn id(&self) -> CallbacksId {
CallbacksId(Arc::as_ptr(&self.data) as usize)
}
}

trait ValueCallback: Send {
Expand Down Expand Up @@ -4578,195 +4574,32 @@ where
}

fn defer_execute_callbacks(callbacks: ChangeCallbacks) {
static THREAD_SENDER: Lazy<mpsc::SyncSender<(ChangeCallbacks, Option<CallbackInvocationId>)>> =
Lazy::new(|| {
let (sender, receiver) = mpsc::sync_channel(256);
std::thread::spawn(move || CallbackExecutor::new(receiver).run());
sender
});
let id = EXECUTING_CALLBACK_ROOT.get();
let _ = THREAD_SENDER.send((callbacks, id));
}

#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
struct CallbacksId(usize);

#[derive(Clone, Copy, Eq, PartialEq)]
struct CallbackInvocationId {
node: LotId,
invocation_root: CallbacksId,
}

struct InvocationTreeNode {
id: CallbacksId,
enqueued: Option<ChangeCallbacks>,
parent: Option<LotId>,
first_child: Option<LotId>,
root_callbacks: CallbacksId,
next: Option<LotId>,
}

#[derive(Default)]
struct InvocationTree {
nodes: Lots<InvocationTreeNode>,
}

impl InvocationTree {
fn new_root(&mut self, callbacks: ChangeCallbacks) -> CallbackInvocationId {
let callbacks_id = callbacks.id();
let node = self.nodes.push(InvocationTreeNode {
id: callbacks_id,
parent: None,
first_child: None,
root_callbacks: callbacks_id,
enqueued: Some(callbacks),
next: None,
});
CallbackInvocationId {
node,
invocation_root: callbacks_id,
}
}

/// Pushes `invoked` into the list of callbacks executed by group of
/// callbacks pointed to by `invoked_by`.
///
/// # Errors
///
/// Returns an error if `invoked` has already been executed by this group of
/// callbacks.
fn push(
&mut self,
callbacks: ChangeCallbacks,
enqueued_while_executing: Option<CallbackInvocationId>,
) -> Option<CallbackInvocationId> {
if let Some(enqueued_while_executing) = enqueued_while_executing {
// Verify that `callbacks` wasn't executed in the chain leading to the node that was executing this node.
let mut search = enqueued_while_executing.node;
let callbacks_id = callbacks.id();
loop {
let node = &self.nodes[search];
if node.id == callbacks_id {
// This set of callbacks has already been executed in this
// chain.
return None;
}
let Some(parent) = node.parent else {
break;
};
search = parent;
}
let root_invoked_by = enqueued_while_executing.invocation_root;
let existing_first_child = self.nodes[enqueued_while_executing.node].first_child;
if let Some(mut node_id) = existing_first_child {
loop {
let node = &mut self.nodes[node_id];
if node.id == callbacks_id {
if let Some(enqueued) = &mut node.enqueued {
enqueued.changed_at = enqueued.changed_at.max(callbacks.changed_at);
return None;
}

node.enqueued = Some(callbacks);
return Some(CallbackInvocationId {
node: node_id,
invocation_root: root_invoked_by,
});
}

if let Some(next) = node.next {
// Continue traversing the list.
node_id = next;
} else {
break;
}
}
}

// `callbacks` hasn't been executed by the list pointed at by
// `enqueued_while_executing`.
let id = self.nodes.push(InvocationTreeNode {
id: callbacks.id(),
enqueued: Some(callbacks),
parent: Some(enqueued_while_executing.node),
first_child: None,
root_callbacks: root_invoked_by,
next: existing_first_child,
});
self.nodes[enqueued_while_executing.node].first_child = Some(id);
Some(CallbackInvocationId {
node: id,
invocation_root: root_invoked_by,
})
} else {
// New root
Some(self.new_root(callbacks))
}
}

fn complete(&mut self, invocation: CallbackInvocationId) {
self.remove_completed_recursive(invocation.node);
}

fn remove_completed_recursive(&mut self, mut node_id: LotId) {
let mut node = &mut self.nodes[node_id];
while node.enqueued.is_none() && node.first_child.is_none() {
let after_removed = node.next;
if let Some(parent_id) = node.parent {
let parent = &mut self.nodes[parent_id];
// Repair the linked list
if parent.first_child == Some(node_id) {
parent.first_child = after_removed;
} else {
let mut current = parent.first_child.expect("valid child");
while self.nodes[current].next != Some(node_id) {
current = self.nodes[current].next.expect("removed node to exist");
}
self.nodes[current].next = after_removed;
}

self.nodes.remove(node_id);

// Attempt to remove the parent if this was its last node.
node = &mut self.nodes[parent_id];
node_id = parent_id;
} else {
self.nodes.remove(node_id);
break;
}
}
}
}

thread_local! {
static EXECUTING_CALLBACK_ROOT: Cell<Option<CallbackInvocationId>> = const { Cell::new(None) };
}

struct EnqueuedCallbacks {
node_id: CallbackInvocationId,
callbacks: ChangeCallbacks,
static THREAD_SENDER: Lazy<mpsc::SyncSender<ChangeCallbacks>> = Lazy::new(|| {
let (sender, receiver) = mpsc::sync_channel(256);
std::thread::spawn(move || CallbackExecutor::new(receiver).run());
sender
});
let _ = THREAD_SENDER.send(callbacks);
}

struct CallbackExecutor {
receiver: mpsc::Receiver<(ChangeCallbacks, Option<CallbackInvocationId>)>,
receiver: mpsc::Receiver<ChangeCallbacks>,

invocations: InvocationTree,
queue: VecDeque<LotId>,
queue: VecDeque<ChangeCallbacks>,
}

impl CallbackExecutor {
fn new(receiver: mpsc::Receiver<(ChangeCallbacks, Option<CallbackInvocationId>)>) -> Self {
fn new(receiver: mpsc::Receiver<ChangeCallbacks>) -> Self {
Self {
receiver,
invocations: InvocationTree::default(),
queue: VecDeque::new(),
}
}

fn enqueue_nonblocking(&mut self) {
// Exhaust any pending callbacks without blocking.
while let Ok((callbacks, invoked_by)) = self.receiver.try_recv() {
self.enqueue(callbacks, invoked_by);
while let Ok(callbacks) = self.receiver.try_recv() {
self.enqueue(callbacks);
}
}

Expand All @@ -4775,53 +4608,22 @@ impl CallbackExecutor {

// Because this is stored in a static, this likely will never return an
// error, but if it does, it's during program shutdown, and we can exit safely.
while let Ok((callbacks, invoked_by)) = self.receiver.recv() {
self.enqueue(callbacks, invoked_by);
while let Ok(callbacks) = self.receiver.recv() {
self.enqueue(callbacks);
self.enqueue_nonblocking();
let mut callbacks_executed = 0;
loop {
let Some(enqueued) = self.pop_callbacks() else {
break;
};
EXECUTING_CALLBACK_ROOT.set(Some(enqueued.node_id));
callbacks_executed += enqueued.callbacks.execute();

// Enqueue any queued operations before we complete this
// invocation to ensure all related invocations are tracked.
self.enqueue_nonblocking();
self.invocations.complete(enqueued.node_id);
while let Some(enqueued) = self.queue.pop_front() {
callbacks_executed += enqueued.execute();
}
EXECUTING_CALLBACK_ROOT.set(None);

// Once we've exited the loop, we can assume all callback invocation
// chains have completed.
assert!(self.invocations.nodes.is_empty());

if callbacks_executed > 0 {
tracing::trace!("{callbacks_executed} callbacks executed");
}
}
}

fn enqueue(&mut self, callbacks: ChangeCallbacks, invoked_by: Option<CallbackInvocationId>) {
if let Some(pushed) = self.invocations.push(callbacks, invoked_by) {
self.queue.push_back(pushed.node);
}
}

fn pop_callbacks(&mut self) -> Option<EnqueuedCallbacks> {
while let Some(id) = self.queue.pop_front() {
if let Some(callbacks) = self.invocations.nodes[id].enqueued.take() {
return Some(EnqueuedCallbacks {
callbacks,
node_id: CallbackInvocationId {
node: id,
invocation_root: self.invocations.nodes[id].root_callbacks,
},
});
}
}

None
fn enqueue(&mut self, callbacks: ChangeCallbacks) {
self.queue.push_back(callbacks);
}

fn is_current_thread() -> bool {
Expand Down Expand Up @@ -4891,21 +4693,36 @@ fn linked_short_circuit() {
#[test]
fn graph_shortcircuit() {
let a = Dynamic::new(0_usize);
let doubled = a.map_each_cloned(|a| a * 2);
let doubled = a.map_each_cloned(|a| dbg!(a) * 2);
let doubled_reader = doubled.create_reader();
let quadrupled = doubled.map_each_cloned(|a| a * 2);
let quadrupled_reader = quadrupled.create_reader();
let quadrupled = doubled.map_each_cloned(|doubled| dbg!(doubled) * 2);
let invocation_count = Dynamic::new(0_usize);
a.set_source(quadrupled.for_each_cloned({
let a = a.clone();
move |quad| a.set(quad / 4)
let invocation_count = invocation_count.clone();
move |quad| {
*invocation_count.lock() += 1;
a.set(dbg!(quad) / 4);
}
}));
let invocation_count = invocation_count.into_reader();

assert_eq!(a.get(), 0);
assert_eq!(quadrupled_reader.get(), 0);
assert_eq!(quadrupled.get(), 0);
a.set(1);
quadrupled_reader.block_until_updated();

// We need to We expect two invocations at this point:
// - Once by using for_each_cloned.
// - Once by the callback chain invoked by setting a to 1.
//
// TODO for_each_cloned is acting like for_each_subsequent_cloned, throwing
// this count off by one
while invocation_count.get() < 1 {
invocation_count.block_until_updated();
}

assert_eq!(doubled_reader.get(), 2);
assert_eq!(quadrupled_reader.get(), 4);
assert_eq!(quadrupled.get(), 4);
quadrupled.set(16);
doubled_reader.block_until_updated();
assert_eq!(a.get(), 4);
Expand Down

0 comments on commit 56536be

Please sign in to comment.