Skip to content

Commit

Permalink
proc_macro: add an optimized CrossThread execution strategy, and a de…
Browse files Browse the repository at this point in the history
…bug flag to use it

This new strategy supports avoiding waiting for a reply for noblock messages.
This strategy requires using a channel-like approach (similar to the previous
CrossThread1 approach).

This new CrossThread execution strategy takes a type parameter for the channel
to use, allowing rustc to use a more efficient channel which the proc_macro
crate could not declare as a dependency.
  • Loading branch information
mystor committed Jul 2, 2021
1 parent b6f8dc9 commit 6ce595a
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 88 deletions.
52 changes: 31 additions & 21 deletions compiler/rustc_expand/src/proc_macro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ use rustc_parse::parser::ForceCollect;
use rustc_span::def_id::CrateNum;
use rustc_span::{Span, DUMMY_SP};

const EXEC_STRATEGY: pm::bridge::server::SameThread = pm::bridge::server::SameThread;
fn exec_strategy(ecx: &ExtCtxt<'_>) -> impl pm::bridge::server::ExecutionStrategy {
<pm::bridge::server::MaybeCrossThread<pm::bridge::server::StdMessagePipe<_>>>::new(
ecx.sess.opts.debugging_opts.proc_macro_cross_thread,
)
}

pub struct BangProcMacro {
pub client: pm::bridge::client::Client<fn(pm::TokenStream) -> pm::TokenStream>,
Expand All @@ -27,14 +31,16 @@ impl base::ProcMacro for BangProcMacro {
input: TokenStream,
) -> Result<TokenStream, ErrorReported> {
let server = proc_macro_server::Rustc::new(ecx, self.krate);
self.client.run(&EXEC_STRATEGY, server, input, ecx.ecfg.proc_macro_backtrace).map_err(|e| {
let mut err = ecx.struct_span_err(span, "proc macro panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
err.emit();
ErrorReported
})
self.client.run(&exec_strategy(ecx), server, input, ecx.ecfg.proc_macro_backtrace).map_err(
|e| {
let mut err = ecx.struct_span_err(span, "proc macro panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
err.emit();
ErrorReported
},
)
}
}

Expand All @@ -53,7 +59,7 @@ impl base::AttrProcMacro for AttrProcMacro {
) -> Result<TokenStream, ErrorReported> {
let server = proc_macro_server::Rustc::new(ecx, self.krate);
self.client
.run(&EXEC_STRATEGY, server, annotation, annotated, ecx.ecfg.proc_macro_backtrace)
.run(&exec_strategy(ecx), server, annotation, annotated, ecx.ecfg.proc_macro_backtrace)
.map_err(|e| {
let mut err = ecx.struct_span_err(span, "custom attribute panicked");
if let Some(s) = e.as_str() {
Expand Down Expand Up @@ -102,18 +108,22 @@ impl MultiItemModifier for ProcMacroDerive {
};

let server = proc_macro_server::Rustc::new(ecx, self.krate);
let stream =
match self.client.run(&EXEC_STRATEGY, server, input, ecx.ecfg.proc_macro_backtrace) {
Ok(stream) => stream,
Err(e) => {
let mut err = ecx.struct_span_err(span, "proc-macro derive panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
err.emit();
return ExpandResult::Ready(vec![]);
let stream = match self.client.run(
&exec_strategy(ecx),
server,
input,
ecx.ecfg.proc_macro_backtrace,
) {
Ok(stream) => stream,
Err(e) => {
let mut err = ecx.struct_span_err(span, "proc-macro derive panicked");
if let Some(s) = e.as_str() {
err.help(&format!("message: {}", s));
}
};
err.emit();
return ExpandResult::Ready(vec![]);
}
};

let error_count_before = ecx.sess.parse_sess.span_diagnostic.err_count();
let mut parser =
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,8 @@ options! {
"print layout information for each type encountered (default: no)"),
proc_macro_backtrace: bool = (false, parse_bool, [UNTRACKED],
"show backtraces for panics during proc-macro execution (default: no)"),
proc_macro_cross_thread: bool = (false, parse_bool, [UNTRACKED],
"run proc-macro code on a separate thread (default: no)"),
profile: bool = (false, parse_bool, [TRACKED],
"insert profiling code (default: no)"),
profile_closures: bool = (false, parse_no_flag, [UNTRACKED],
Expand Down
6 changes: 5 additions & 1 deletion library/proc_macro/src/bridge/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,11 @@ macro_rules! client_send_impl {

b = bridge.dispatch.call(b);

let r = Result::<(), PanicMessage>::decode(&mut &b[..], &mut ());
let r = if b.len() > 0 {
Result::<(), PanicMessage>::decode(&mut &b[..], &mut ())
} else {
Ok(())
};

bridge.cached_buffer = b;

Expand Down
169 changes: 103 additions & 66 deletions library/proc_macro/src/bridge/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
use super::*;

use std::marker::PhantomData;

// FIXME(eddyb) generate the definition of `HandleStore` in `server.rs`.
use super::client::HandleStore;

Expand Down Expand Up @@ -174,6 +176,50 @@ pub trait ExecutionStrategy {
) -> Buffer<u8>;
}

pub struct MaybeCrossThread<P> {
cross_thread: bool,
marker: PhantomData<P>,
}

impl<P> MaybeCrossThread<P> {
pub const fn new(cross_thread: bool) -> Self {
MaybeCrossThread { cross_thread, marker: PhantomData }
}
}

impl<P> ExecutionStrategy for MaybeCrossThread<P>
where
P: MessagePipe<Buffer<u8>> + Send + 'static,
{
fn run_bridge_and_client<D: Copy + Send + 'static>(
&self,
dispatcher: &mut impl DispatcherTrait,
input: Buffer<u8>,
run_client: extern "C" fn(BridgeConfig<'_>, D) -> Buffer<u8>,
client_data: D,
force_show_panics: bool,
) -> Buffer<u8> {
if self.cross_thread {
<CrossThread<P>>::new().run_bridge_and_client(
dispatcher,
input,
run_client,
client_data,
force_show_panics,
)
} else {
SameThread.run_bridge_and_client(
dispatcher,
input,
run_client,
client_data,
force_show_panics,
)
}
}
}

#[derive(Default)]
pub struct SameThread;

impl ExecutionStrategy for SameThread {
Expand All @@ -194,12 +240,18 @@ impl ExecutionStrategy for SameThread {
}
}

// NOTE(eddyb) Two implementations are provided, the second one is a bit
// faster but neither is anywhere near as fast as same-thread execution.
pub struct CrossThread<P>(PhantomData<P>);

pub struct CrossThread1;
impl<P> CrossThread<P> {
pub const fn new() -> Self {
CrossThread(PhantomData)
}
}

impl ExecutionStrategy for CrossThread1 {
impl<P> ExecutionStrategy for CrossThread<P>
where
P: MessagePipe<Buffer<u8>> + Send + 'static,
{
fn run_bridge_and_client<D: Copy + Send + 'static>(
&self,
dispatcher: &mut impl DispatcherTrait,
Expand All @@ -208,15 +260,18 @@ impl ExecutionStrategy for CrossThread1 {
client_data: D,
force_show_panics: bool,
) -> Buffer<u8> {
use std::sync::mpsc::channel;

let (req_tx, req_rx) = channel();
let (res_tx, res_rx) = channel();
let (mut server, mut client) = P::new();

let join_handle = thread::spawn(move || {
let mut dispatch = |b| {
req_tx.send(b).unwrap();
res_rx.recv().unwrap()
let mut dispatch = |b: Buffer<u8>| -> Buffer<u8> {
let method_tag = api_tags::Method::decode(&mut &b[..], &mut ());
client.send(b);

if method_tag.should_wait() {
client.recv().expect("server died while client waiting for reply")
} else {
Buffer::new()
}
};

run_client(
Expand All @@ -225,73 +280,55 @@ impl ExecutionStrategy for CrossThread1 {
)
});

for b in req_rx {
res_tx.send(dispatcher.dispatch(b)).unwrap();
while let Some(b) = server.recv() {
let method_tag = api_tags::Method::decode(&mut &b[..], &mut ());
let b = dispatcher.dispatch(b);

if method_tag.should_wait() {
server.send(b);
} else if let Err(err) = <Result<(), PanicMessage>>::decode(&mut &b[..], &mut ()) {
panic::resume_unwind(err.into());
}
}

join_handle.join().unwrap()
}
}

pub struct CrossThread2;

impl ExecutionStrategy for CrossThread2 {
fn run_bridge_and_client<D: Copy + Send + 'static>(
&self,
dispatcher: &mut impl DispatcherTrait,
input: Buffer<u8>,
run_client: extern "C" fn(BridgeConfig<'_>, D) -> Buffer<u8>,
client_data: D,
force_show_panics: bool,
) -> Buffer<u8> {
use std::sync::{Arc, Mutex};

enum State<T> {
Req(T),
Res(T),
}

let mut state = Arc::new(Mutex::new(State::Res(Buffer::new())));
/// A message pipe used for communicating between server and client threads.
pub trait MessagePipe<T>: Sized {
/// Create a new pair of endpoints for the message pipe.
fn new() -> (Self, Self);

let server_thread = thread::current();
let state2 = state.clone();
let join_handle = thread::spawn(move || {
let mut dispatch = |b| {
*state2.lock().unwrap() = State::Req(b);
server_thread.unpark();
loop {
thread::park();
if let State::Res(b) = &mut *state2.lock().unwrap() {
break b.take();
}
}
};
/// Send a message to the other endpoint of this pipe.
fn send(&mut self, value: T);

let r = run_client(
BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics },
client_data,
);
/// Receive a message from the other endpoint of this pipe.
///
/// Returns `None` if the other end of the pipe has been destroyed, and no
/// message was received.
fn recv(&mut self) -> Option<T>;
}

// Wake up the server so it can exit the dispatch loop.
drop(state2);
server_thread.unpark();
/// Implementation of `MessagePipe` using `std::sync::mpsc`
pub struct StdMessagePipe<T> {
tx: std::sync::mpsc::Sender<T>,
rx: std::sync::mpsc::Receiver<T>,
}

r
});
impl<T> MessagePipe<T> for StdMessagePipe<T> {
fn new() -> (Self, Self) {
let (tx1, rx1) = std::sync::mpsc::channel();
let (tx2, rx2) = std::sync::mpsc::channel();
(StdMessagePipe { tx: tx1, rx: rx2 }, StdMessagePipe { tx: tx2, rx: rx1 })
}

// Check whether `state2` was dropped, to know when to stop.
while Arc::get_mut(&mut state).is_none() {
thread::park();
let mut b = match &mut *state.lock().unwrap() {
State::Req(b) => b.take(),
_ => continue,
};
b = dispatcher.dispatch(b.take());
*state.lock().unwrap() = State::Res(b);
join_handle.thread().unpark();
}
fn send(&mut self, v: T) {
self.tx.send(v).unwrap();
}

join_handle.join().unwrap()
fn recv(&mut self) -> Option<T> {
self.rx.recv().ok()
}
}

Expand Down

0 comments on commit 6ce595a

Please sign in to comment.