Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cancellation of requests #125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ use core::{marker::PhantomData, task::Poll};
use crate::api::*;
use crate::backend::{BackendId, CoreOnly, Dispatch};
use crate::error::*;
use crate::interrupt::InterruptFlag;
use crate::pipe::{TrussedRequester, TRUSSED_INTERCHANGE};
use crate::service::Service;
use crate::types::*;
Expand Down Expand Up @@ -113,6 +114,9 @@ impl<S: Syscall, E> Client for ClientImplementation<S, E> {}
pub trait PollClient {
fn request<Rq: RequestVariant>(&mut self, req: Rq) -> ClientResult<'_, Rq::Reply, Self>;
fn poll(&mut self) -> Poll<Result<Reply, Error>>;
fn interrupt(&self) -> Option<&'static InterruptFlag> {
None
}
}

pub struct FutureResult<'c, T, C: ?Sized>
Expand Down Expand Up @@ -148,6 +152,7 @@ pub struct ClientImplementation<S, D = CoreOnly> {

// RawClient:
pub(crate) interchange: TrussedRequester,
pub(crate) interrupt: Option<&'static InterruptFlag>,
// pending: Option<Discriminant<Request>>,
pending: Option<u8>,
_marker: PhantomData<D>,
Expand All @@ -165,11 +170,16 @@ impl<S, E> ClientImplementation<S, E>
where
S: Syscall,
{
pub fn new(interchange: TrussedRequester, syscall: S) -> Self {
pub fn new(
interchange: TrussedRequester,
syscall: S,
interrupt: Option<&'static InterruptFlag>,
) -> Self {
Self {
interchange,
pending: None,
syscall,
interrupt,
_marker: Default::default(),
}
}
Expand Down Expand Up @@ -205,7 +215,14 @@ where
}
}
}
None => Poll::Pending,
None => {
debug_assert_ne!(
self.interchange.state(),
interchange::State::Idle,
"requests can't be cancelled"
);
Poll::Pending
}
}
}

Expand All @@ -227,6 +244,10 @@ where
self.syscall.syscall();
Ok(FutureResult::new(self))
}

fn interrupt(&self) -> Option<&'static InterruptFlag> {
self.interrupt
}
}

impl<S: Syscall, E> CertificateClient for ClientImplementation<S, E> {}
Expand Down Expand Up @@ -701,6 +722,7 @@ pub trait UiClient: PollClient {
pub struct ClientBuilder<D: Dispatch = CoreOnly> {
id: PathBuf,
backends: &'static [BackendId<D::BackendId>],
interrupt: Option<&'static InterruptFlag>,
}

impl ClientBuilder {
Expand All @@ -712,6 +734,7 @@ impl ClientBuilder {
Self {
id: id.into(),
backends: &[],
interrupt: None,
}
}
}
Expand All @@ -727,17 +750,22 @@ impl<D: Dispatch> ClientBuilder<D> {
ClientBuilder {
id: self.id,
backends,
interrupt: self.interrupt,
}
}

pub fn interrupt(self, interrupt: Option<&'static InterruptFlag>) -> Self {
Self { interrupt, ..self }
}

fn create_endpoint<P: Platform>(
self,
service: &mut Service<P, D>,
) -> Result<TrussedRequester, Error> {
let (requester, responder) = TRUSSED_INTERCHANGE
.claim()
.ok_or(Error::ClientCountExceeded)?;
service.add_endpoint(responder, self.id, self.backends)?;
service.add_endpoint(responder, self.id, self.backends, self.interrupt)?;
Ok(requester)
}

Expand All @@ -749,8 +777,9 @@ impl<D: Dispatch> ClientBuilder<D> {
self,
service: &mut Service<P, D>,
) -> Result<PreparedClient<D>, Error> {
let interrupt = self.interrupt;
self.create_endpoint(service)
.map(|requester| PreparedClient::new(requester))
.map(|requester| PreparedClient::new(requester, interrupt))
}
}

Expand All @@ -761,20 +790,22 @@ impl<D: Dispatch> ClientBuilder<D> {
/// implementation.
pub struct PreparedClient<D> {
requester: TrussedRequester,
interrupt: Option<&'static InterruptFlag>,
_marker: PhantomData<D>,
}

impl<D> PreparedClient<D> {
fn new(requester: TrussedRequester) -> Self {
fn new(requester: TrussedRequester, interrupt: Option<&'static InterruptFlag>) -> Self {
Self {
requester,
interrupt,
_marker: Default::default(),
}
}

/// Builds the client using the given syscall implementation.
pub fn build<S: Syscall>(self, syscall: S) -> ClientImplementation<S, D> {
ClientImplementation::new(self.requester, syscall)
ClientImplementation::new(self.requester, syscall, self.interrupt)
}
}

Expand Down
77 changes: 77 additions & 0 deletions src/interrupt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use core::{
fmt::Debug,
sync::atomic::{AtomicU8, Ordering::Relaxed},
};

#[derive(Default, Debug, PartialEq, Eq)]
pub enum InterruptState {
#[default]
Idle = 0,
Working = 1,
Interrupted = 2,
}

#[derive(Default, Debug, PartialEq, Eq, Clone)]
pub struct FromU8Error;

impl TryFrom<u8> for InterruptState {
type Error = FromU8Error;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Idle),
1 => Ok(Self::Working),
2 => Ok(Self::Interrupted),
_ => Err(FromU8Error),
}
}
}

impl From<InterruptState> for u8 {
fn from(value: InterruptState) -> Self {
value as _
}
}

#[derive(Default)]
pub struct InterruptFlag(AtomicU8);

const CONV_ERROR: &str =
"Internal trussed error: InterruptState must always be set to an enum variant";

impl InterruptFlag {
pub const fn new() -> Self {
Self(AtomicU8::new(0))
}
fn load(&self) -> InterruptState {
self.0.load(Relaxed).try_into().expect(CONV_ERROR)
}

pub fn set_idle(&self) {
self.0.store(InterruptState::Idle.into(), Relaxed)
}
pub fn set_working(&self) {
self.0.store(InterruptState::Working.into(), Relaxed)
}
pub fn interrupt(&self) -> bool {
self.0
.compare_exchange(
InterruptState::Working.into(),
InterruptState::Interrupted.into(),
Relaxed,
Relaxed,
)
.is_ok()
}

pub fn is_interrupted(&self) -> bool {
let res = self.load();
info_now!("got interrupt state: {:?}", res);
res == InterruptState::Interrupted
}
}

impl Debug for InterruptFlag {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.load().fmt(f)
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub mod backend;
pub mod client;
pub mod config;
pub mod error;
pub mod interrupt;
pub mod key;
pub mod mechanisms;
pub mod pipe;
Expand Down
57 changes: 41 additions & 16 deletions src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use littlefs2::{
use rand_chacha::ChaCha8Rng;
pub use rand_core::{RngCore, SeedableRng};

use crate::api::*;
use crate::backend::{BackendId, CoreOnly, Dispatch};
use crate::client::{ClientBuilder, ClientImplementation};
use crate::config::*;
Expand All @@ -23,6 +22,7 @@ pub use crate::store::{
};
use crate::types::*;
use crate::Bytes;
use crate::{api::*, interrupt::InterruptFlag};

pub mod attest;

Expand Down Expand Up @@ -106,8 +106,8 @@ impl<P: Platform> ServiceResources<P> {
.map_err(|_| Error::EntropyMalfunction)
}

pub fn filestore(&mut self, ctx: &CoreContext) -> ClientFilestore<P::S> {
ClientFilestore::new(ctx.path.clone(), self.platform.store())
pub fn filestore(&mut self, client_id: PathBuf) -> ClientFilestore<P::S> {
ClientFilestore::new(client_id, self.platform.store())
}

pub fn trussed_filestore(&mut self) -> ClientFilestore<P::S> {
Expand Down Expand Up @@ -143,7 +143,7 @@ impl<P: Platform> ServiceResources<P> {
let keystore = &mut self.keystore(ctx)?;
let certstore = &mut self.certstore(ctx)?;
let counterstore = &mut self.counterstore(ctx)?;
let filestore = &mut self.filestore(ctx);
let filestore = &mut self.filestore(ctx.path.clone());

debug_now!("TRUSSED {:?}", request);
match request {
Expand Down Expand Up @@ -515,30 +515,43 @@ impl<P: Platform> ServiceResources<P> {
},

Request::RequestUserConsent(request) => {

// assert_eq!(request.level, consent::Level::Normal);

let starttime = self.platform.user_interface().uptime();
let timeout = core::time::Duration::from_millis(request.timeout_milliseconds as u64);
let timeout =
core::time::Duration::from_millis(request.timeout_milliseconds as u64);

let previous_status = self.platform.user_interface().status();
self.platform.user_interface().set_status(ui::Status::WaitingForUserPresence);
self.platform
.user_interface()
.set_status(ui::Status::WaitingForUserPresence);
loop {
if ctx.interrupt.map(|i| i.is_interrupted()) == Some(true) {
info_now!("User presence request cancelled");
return Ok(reply::RequestUserConsent {
result: Err(consent::Error::Interrupted),
}
.into());
}

self.platform.user_interface().refresh();
let nowtime = self.platform.user_interface().uptime();
if (nowtime - starttime) > timeout {
let result = Err(consent::Error::TimedOut);
return Ok(Reply::RequestUserConsent(reply::RequestUserConsent { result } ));
return Ok(Reply::RequestUserConsent(reply::RequestUserConsent {
result,
}));
}
let up = self.platform.user_interface().check_user_presence();
match request.level {
// If Normal level consent is request, then both Strong and Normal
// indications will result in success.
consent::Level::Normal => {
if up == consent::Level::Normal ||
up == consent::Level::Strong {
break;
}
},
if up == consent::Level::Normal || up == consent::Level::Strong {
break;
}
}
// Otherwise, only strong level indication will work.
consent::Level::Strong => {
if up == consent::Level::Strong {
Expand All @@ -553,7 +566,9 @@ impl<P: Platform> ServiceResources<P> {
self.platform.user_interface().set_status(previous_status);

let result = Ok(());
Ok(Reply::RequestUserConsent(reply::RequestUserConsent { result } ))
Ok(Reply::RequestUserConsent(reply::RequestUserConsent {
result,
}))
}

Request::Reboot(request) => {
Expand Down Expand Up @@ -709,8 +724,10 @@ impl<P: Platform> Service<P> {
&mut self,
client_id: &str,
syscall: S,
interrupt: Option<&'static InterruptFlag>,
) -> Result<ClientImplementation<S>, Error> {
ClientBuilder::new(client_id)
.interrupt(interrupt)
.prepare(self)
.map(|p| p.build(syscall))
}
Expand All @@ -721,8 +738,10 @@ impl<P: Platform> Service<P> {
pub fn try_as_new_client(
&mut self,
client_id: &str,
interrupt: Option<&'static InterruptFlag>,
) -> Result<ClientImplementation<&mut Self>, Error> {
ClientBuilder::new(client_id)
.interrupt(interrupt)
.prepare(self)
.map(|p| p.build(self))
}
Expand All @@ -732,8 +751,10 @@ impl<P: Platform> Service<P> {
pub fn try_into_new_client(
mut self,
client_id: &str,
interrupt: Option<&'static InterruptFlag>,
) -> Result<ClientImplementation<Self>, Error> {
ClientBuilder::new(client_id)
.interrupt(interrupt)
.prepare(&mut self)
.map(|p| p.build(self))
}
Expand All @@ -743,10 +764,11 @@ impl<P: Platform, D: Dispatch> Service<P, D> {
pub fn add_endpoint(
&mut self,
interchange: TrussedResponder,
core_ctx: impl Into<CoreContext>,
client: impl Into<PathBuf>,
backends: &'static [BackendId<D::BackendId>],
interrupt: Option<&'static InterruptFlag>,
) -> Result<(), Error> {
let core_ctx = core_ctx.into();
let core_ctx = CoreContext::with_interrupt(client.into(), interrupt);
if &*core_ctx.path == path!("trussed") {
panic!("trussed is a reserved client ID");
}
Expand Down Expand Up @@ -813,7 +835,10 @@ impl<P: Platform, D: Dispatch> Service<P, D> {
.platform
.user_interface()
.set_status(ui::Status::Idle);
ep.interchange.respond(reply_result).ok();
if ep.interchange.respond(reply_result).is_err() && ep.interchange.is_canceled() {
info!("Cancelled request");
ep.interchange.acknowledge_cancel().ok();
};
}
}
debug_now!(
Expand Down
Loading