Skip to content

Commit

Permalink
WIP: fix timeouts use-after-free
Browse files Browse the repository at this point in the history
  • Loading branch information
yorickpeterse committed Jan 7, 2025
1 parent ddf2eeb commit c8021bf
Show file tree
Hide file tree
Showing 8 changed files with 360 additions and 568 deletions.
13 changes: 7 additions & 6 deletions rt/src/network_poller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,19 @@ impl Worker {
let mut processes = poller.poll(&mut events);

processes.retain(|proc| {
let mut state = proc.state();
let rights = state.try_reschedule_for_io();
// Acquiring the rights first _then_ matching on them ensures we
// don't deadlock with the timeout worker.
let rights = proc.state().try_reschedule_for_io();

// A process may have also been registered with the timeout
// thread (e.g. when using a timeout). As such we should only
// reschedule the process if the timout thread didn't already do
// this for us.
// reschedule the process if the timeout thread didn't already
// do this for us.
match rights {
RescheduleRights::Failed => false,
RescheduleRights::Acquired => true,
RescheduleRights::AcquiredWithTimeout => {
self.state.timeout_worker.increase_expired_timeouts();
RescheduleRights::AcquiredWithTimeout(id) => {
self.state.timeout_worker.expire(id);
true
}
}
Expand Down
93 changes: 30 additions & 63 deletions rt/src/process.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::arc_without_weak::ArcWithoutWeak;
use crate::mem::{allocate, header_of, Header, TypePointer};
use crate::scheduler::process::Thread;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::Id as TimeoutId;
use crate::stack::Stack;
use crate::state::State;
use std::alloc::dealloc;
Expand Down Expand Up @@ -200,7 +199,7 @@ pub(crate) enum RescheduleRights {

/// The rescheduling rights were obtained, and the process was using a
/// timeout.
AcquiredWithTimeout,
AcquiredWithTimeout(TimeoutId),
}

impl RescheduleRights {
Expand All @@ -219,11 +218,11 @@ pub(crate) struct ProcessState {
/// The status of the process.
status: ProcessStatus,

/// The timeout this process is suspended with, if any.
/// The ID of the timeout this process is suspended with, if any.
///
/// If missing and the process is suspended, it means the process is
/// suspended indefinitely.
timeout: Option<ArcWithoutWeak<Timeout>>,
timeout: Option<TimeoutId>,
}

impl ProcessState {
Expand All @@ -235,17 +234,7 @@ impl ProcessState {
}
}

pub(crate) fn has_same_timeout(
&self,
timeout: &ArcWithoutWeak<Timeout>,
) -> bool {
self.timeout
.as_ref()
.map(|t| t.as_ptr() == timeout.as_ptr())
.unwrap_or(false)
}

pub(crate) fn suspend(&mut self, timeout: ArcWithoutWeak<Timeout>) {
pub(crate) fn suspend(&mut self, timeout: TimeoutId) {
self.timeout = Some(timeout);
self.status.set_sleeping(true);
}
Expand All @@ -265,25 +254,19 @@ impl ProcessState {

self.status.no_longer_waiting();

if self.timeout.take().is_some() {
RescheduleRights::AcquiredWithTimeout
if let Some(id) = self.timeout.take() {
RescheduleRights::AcquiredWithTimeout(id)
} else {
RescheduleRights::Acquired
}
}

pub(crate) fn waiting_for_value(
&mut self,
timeout: Option<ArcWithoutWeak<Timeout>>,
) {
pub(crate) fn waiting_for_value(&mut self, timeout: Option<TimeoutId>) {
self.timeout = timeout;
self.status.set_waiting_for_value(true);
}

pub(crate) fn waiting_for_io(
&mut self,
timeout: Option<ArcWithoutWeak<Timeout>>,
) {
pub(crate) fn waiting_for_io(&mut self, timeout: Option<TimeoutId>) {
self.timeout = timeout;
self.status.set_waiting_for_io(true);
}
Expand All @@ -304,8 +287,8 @@ impl ProcessState {

self.status.set_waiting_for_value(false);

if self.timeout.take().is_some() {
RescheduleRights::AcquiredWithTimeout
if let Some(id) = self.timeout.take() {
RescheduleRights::AcquiredWithTimeout(id)
} else {
RescheduleRights::Acquired
}
Expand All @@ -318,8 +301,8 @@ impl ProcessState {

self.status.set_waiting_for_io(false);

if self.timeout.take().is_some() {
RescheduleRights::AcquiredWithTimeout
if let Some(id) = self.timeout.take() {
RescheduleRights::AcquiredWithTimeout(id)
} else {
RescheduleRights::Acquired
}
Expand Down Expand Up @@ -719,10 +702,10 @@ impl DerefMut for ProcessPointer {
#[cfg(test)]
mod tests {
use super::*;
use crate::test::{empty_process_type, setup, OwnedProcess};
use crate::test::{empty_process_type, OwnedProcess};
use rustix::param::page_size;
use std::mem::size_of;
use std::time::Duration;
use std::num::NonZeroU64;

macro_rules! offset_of {
($value: expr, $field: ident) => {{
Expand Down Expand Up @@ -857,25 +840,14 @@ mod tests {
fn test_reschedule_rights_are_acquired() {
assert!(!RescheduleRights::Failed.are_acquired());
assert!(RescheduleRights::Acquired.are_acquired());
assert!(RescheduleRights::AcquiredWithTimeout.are_acquired());
}

#[test]
fn test_process_state_has_same_timeout() {
let state = setup();
let mut proc_state = ProcessState::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));

assert!(!proc_state.has_same_timeout(&timeout));

proc_state.timeout = Some(timeout.clone());

assert!(proc_state.has_same_timeout(&timeout));
assert!(RescheduleRights::AcquiredWithTimeout(TimeoutId(
NonZeroU64::new(1).unwrap()
))
.are_acquired());
}

#[test]
fn test_process_state_try_reschedule_after_timeout() {
let state = setup();
let mut proc_state = ProcessState::new();

assert_eq!(
Expand All @@ -893,13 +865,13 @@ mod tests {
assert!(!proc_state.status.is_waiting_for_value());
assert!(!proc_state.status.is_waiting());

let timeout = Timeout::duration(&state, Duration::from_secs(0));
let id = TimeoutId(NonZeroU64::new(1).unwrap());

proc_state.waiting_for_value(Some(timeout));
proc_state.waiting_for_value(Some(id));

assert_eq!(
proc_state.try_reschedule_after_timeout(),
RescheduleRights::AcquiredWithTimeout
RescheduleRights::AcquiredWithTimeout(id)
);

assert!(!proc_state.status.is_waiting_for_value());
Expand All @@ -908,16 +880,15 @@ mod tests {

#[test]
fn test_process_state_waiting_for_value() {
let state = setup();
let mut proc_state = ProcessState::new();
let timeout = Timeout::duration(&state, Duration::from_secs(0));

proc_state.waiting_for_value(None);

assert!(proc_state.status.is_waiting_for_value());
assert!(proc_state.timeout.is_none());

proc_state.waiting_for_value(Some(timeout));
proc_state
.waiting_for_value(Some(TimeoutId(NonZeroU64::new(1).unwrap())));

assert!(proc_state.status.is_waiting_for_value());
assert!(proc_state.timeout.is_some());
Expand All @@ -943,7 +914,6 @@ mod tests {

#[test]
fn test_process_state_try_reschedule_for_value() {
let state = setup();
let mut proc_state = ProcessState::new();

assert_eq!(
Expand All @@ -958,13 +928,14 @@ mod tests {
);
assert!(!proc_state.status.is_waiting_for_value());

let id = TimeoutId(NonZeroU64::new(1).unwrap());

proc_state.status.set_waiting_for_value(true);
proc_state.timeout =
Some(Timeout::duration(&state, Duration::from_secs(0)));
proc_state.timeout = Some(id);

assert_eq!(
proc_state.try_reschedule_for_value(),
RescheduleRights::AcquiredWithTimeout
RescheduleRights::AcquiredWithTimeout(id)
);
assert!(!proc_state.status.is_waiting_for_value());
}
Expand Down Expand Up @@ -1004,29 +975,25 @@ mod tests {

#[test]
fn test_process_state_suspend() {
let state = setup();
let typ = empty_process_type("A");
let stack = Stack::new(32, page_size());
let process = OwnedProcess::new(Process::alloc(*typ, stack));
let timeout = Timeout::duration(&state, Duration::from_secs(0));

process.state().suspend(timeout);
process.state().suspend(TimeoutId(NonZeroU64::new(1).unwrap()));

assert!(process.state().timeout.is_some());
assert!(process.state().status.is_waiting());
}

#[test]
fn test_process_timeout_expired() {
let state = setup();
let typ = empty_process_type("A");
let stack = Stack::new(32, page_size());
let process = OwnedProcess::new(Process::alloc(*typ, stack));
let timeout = Timeout::duration(&state, Duration::from_secs(0));

assert!(!process.timeout_expired());

process.state().suspend(timeout);
process.state().suspend(TimeoutId(NonZeroU64::new(1).unwrap()));

assert!(!process.timeout_expired());
assert!(!process.state().status.timeout_expired());
Expand Down
35 changes: 20 additions & 15 deletions rt/src/runtime/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::process::{
StackFrame,
};
use crate::scheduler::process::Action;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::Deadline;
use crate::state::State;
use std::fmt::Write as _;
use std::process::exit;
Expand Down Expand Up @@ -83,8 +83,8 @@ pub unsafe extern "system" fn inko_process_send_message(
let message = Message { method, data };
let state = &*state;
let reschedule = match receiver.send_message(message) {
RescheduleRights::AcquiredWithTimeout => {
state.timeout_worker.increase_expired_timeouts();
RescheduleRights::AcquiredWithTimeout(id) => {
state.timeout_worker.expire(id);
true
}
RescheduleRights::Acquired => true,
Expand Down Expand Up @@ -129,13 +129,16 @@ pub unsafe extern "system" fn inko_process_suspend(
nanos: i64,
) {
let state = &*state;
let timeout = Timeout::duration(state, Duration::from_nanos(nanos as _));
let timeout = Deadline::duration(state, Duration::from_nanos(nanos as _));

{
// We need to hold on to the lock until the end so as to ensure the process
// is rescheduled if the timeout happens to expire before we finish the
// work here.
let mut proc_state = process.state();
let timeout_id = state.timeout_worker.suspend(process, timeout);

proc_state.suspend(timeout.clone());
state.timeout_worker.suspend(process, timeout);
proc_state.suspend(timeout_id);
}

// Safety: the current thread is holding on to the run lock
Expand Down Expand Up @@ -243,22 +246,21 @@ pub unsafe extern "system" fn inko_process_wait_for_value_until(
nanos: u64,
) -> bool {
let state = &*state;
let deadline = Timeout::until(nanos);
let deadline = Deadline::until(nanos);
let mut proc_state = process.state();

proc_state.waiting_for_value(Some(deadline.clone()));

let _ = (*lock).compare_exchange(
current,
new,
Ordering::AcqRel,
Ordering::Acquire,
);

let timeout_id = state.timeout_worker.suspend(process, deadline);

proc_state.waiting_for_value(Some(timeout_id));
drop(proc_state);

// Safety: the current thread is holding on to the run lock
state.timeout_worker.suspend(process, deadline);
context::switch(process);
process.timeout_expired()
}
Expand All @@ -270,12 +272,15 @@ pub unsafe extern "system" fn inko_process_reschedule_for_value(
waiter: ProcessPointer,
) {
let state = &*state;
let mut waiter_state = waiter.state();
let reschedule = match waiter_state.try_reschedule_for_value() {

// Acquiring the rights first _then_ matching on them ensures we don't
// deadlock with the timeout worker.
let rights = waiter.state().try_reschedule_for_value();
let reschedule = match rights {
RescheduleRights::Failed => false,
RescheduleRights::Acquired => true,
RescheduleRights::AcquiredWithTimeout => {
state.timeout_worker.increase_expired_timeouts();
RescheduleRights::AcquiredWithTimeout(id) => {
state.timeout_worker.expire(id);
true
}
};
Expand Down
8 changes: 4 additions & 4 deletions rt/src/runtime/socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::context;
use crate::network_poller::Interest;
use crate::process::ProcessPointer;
use crate::scheduler::timeouts::Timeout;
use crate::scheduler::timeouts::Deadline;
use crate::socket::Socket;
use crate::state::State;

Expand All @@ -26,10 +26,10 @@ pub(crate) unsafe extern "system" fn inko_socket_poll(

// A deadline of -1 signals that we should wait indefinitely.
if deadline >= 0 {
let time = Timeout::until(deadline as u64);
let time = Deadline::until(deadline as u64);
let timeout_id = state.timeout_worker.suspend(process, time);

proc_state.waiting_for_io(Some(time.clone()));
state.timeout_worker.suspend(process, time);
proc_state.waiting_for_io(Some(timeout_id));
} else {
proc_state.waiting_for_io(None);
}
Expand Down
1 change: 0 additions & 1 deletion rt/src/scheduler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
pub mod process;
pub mod signal;
pub mod timeout_worker;
pub mod timeouts;

#[cfg(target_os = "linux")]
Expand Down
Loading

0 comments on commit c8021bf

Please sign in to comment.