Skip to content

Commit

Permalink
Make generated 'project' reference take an '&mut Pin<&mut Self>'
Browse files Browse the repository at this point in the history
Based on rust-lang/unsafe-code-guidelines#148 (comment)
by @CAD97

Currently, the generated 'project' method takes a 'Pin<&mut Self>',
consuming it. This makes it impossible to use the original Pin<&mut Self>
after calling project(), since the 'Pin<&mut Self>' has been moved into
the the 'Project' method.

This makes it impossible to implement useful pattern when working with
enums:

```rust

enum Foo {
    Variant1(#[pin] SomeFuture),
    Variant2(OtherType)
}

fn process(foo: Pin<&mut Foo>) {
    match foo.project() {
        __FooProjection(fut) => {
            fut.poll();
            let new_foo: Foo = ...;
            foo.set(new_foo);
        },
        _ => {}
    }
}
```

This pattern is common when implementing a Future combinator - an inner
future is polled, and then the containing enum is changed to a new
variant. However, as soon as 'project()' is called, it becoms imposible
to call 'set' on the original 'Pin<&mut Self>'.

To support this pattern, this commit changes the 'project' method to
take a '&mut Pin<&mut Self>'. The projection types works exactly as
before - however, creating it no longer requires consuming the original
'Pin<&mut Self>'

Unfortunately, current limitations of Rust prevent us from simply
modifiying the signature of the 'project' method in the inherent impl
of the projection type. While using 'Pin<&mut Self>' as a receiver is
supported on stable rust, using '&mut Pin<&mut Self>' as a receiver
requires the unstable `#![feature(arbitrary_self_types)]`

For compatibility with stable Rust, we instead dynamically define a new
trait, '__{Type}ProjectionTrait', where {Type} is the name of the type
with the `#[pin_project]` attribute.

This trait looks like this:

```rust
trait __FooProjectionTrait {
    fn project(&'a mut self) -> __FooProjection<'a>;
}
```

It is then implemented for `Pin<&mut {Type}>`. This allows the `project`
method to be invoked on `&mut Pin<&mut {Type}>`, which is what we want.

If Generic Associated Types (rust-lang/rust#44265)
were implemented and stablized, we could use a single trait for all pin
projections:

```rust
trait Projectable {
    type Projection<'a>;
    fn project(&'a mut self) -> Self::Projection<'a>;
}
```

However, Generic Associated Types are not even implemented on nightly
yet, so we need for generate a new trait per type for the forseeable
future.
  • Loading branch information
Aaron1011 committed Aug 23, 2019
1 parent 62b4921 commit d20f3fe
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 21 deletions.
7 changes: 4 additions & 3 deletions pin-project-internal/src/pin_project/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@ pub(super) fn parse(mut cx: Context, mut item: ItemEnum) -> Result<TokenStream>
let Context { original, projected, lifetime, impl_unpin, .. } = cx;
let proj_generics = proj_generics(&item.generics, &lifetime);
let proj_ty_generics = proj_generics.split_for_impl().1;
let proj_trait = &cx.projected_trait;
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let mut proj_items = quote! {
enum #projected #proj_generics #where_clause { #(#proj_variants,)* }
};
let proj_method = quote! {
impl #impl_generics #original #ty_generics #where_clause {
fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #projected #proj_ty_generics {
impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #original #ty_generics> #where_clause {
fn project<#lifetime>(&#lifetime mut self) -> #projected #proj_ty_generics #where_clause {
unsafe {
match ::core::pin::Pin::get_unchecked_mut(self) {
match self.as_mut().get_unchecked_mut() {
#(#proj_arms,)*
}
}
Expand Down
43 changes: 35 additions & 8 deletions pin-project-internal/src/pin_project/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
token::Comma,
Fields, FieldsNamed, FieldsUnnamed, GenericParam, Generics, Index, Item, ItemStruct, Lifetime,
LifetimeDef, Meta, NestedMeta, Result, Type,
*
};

use crate::utils::{crate_path, proj_ident};
use crate::utils::{crate_path, proj_ident, proj_trait_ident};

mod enums;
mod structs;
Expand Down Expand Up @@ -51,6 +50,10 @@ struct Context {
original: Ident,
/// Name of the projected type.
projected: Ident,
/// Name of the trait generated
/// to provide a 'project' method
projected_trait: Ident,
generics: Generics,

lifetime: Lifetime,
impl_unpin: ImplUnpin,
Expand All @@ -63,7 +66,8 @@ impl Context {
let projected = proj_ident(&original);
let lifetime = proj_lifetime(&generics.params);
let impl_unpin = ImplUnpin::new(generics, unsafe_unpin);
Ok(Self { original, projected, lifetime, impl_unpin, pinned_drop })
let projected_trait = proj_trait_ident(&original);
Ok(Self { original, projected, projected_trait, lifetime, impl_unpin, pinned_drop, generics: generics.clone() })
}

fn impl_drop<'a>(&self, generics: &'a Generics) -> ImplDrop<'a> {
Expand All @@ -72,24 +76,47 @@ impl Context {
}

fn parse(args: TokenStream, input: TokenStream) -> Result<TokenStream> {

match syn::parse2(input)? {
Item::Struct(item) => {
let cx = Context::new(args, item.ident.clone(), &item.generics)?;
let mut cx = Context::new(args, item.ident.clone(), &item.generics)?;
let mut res = make_proj_trait(&mut cx)?;

let packed_check = ensure_not_packed(&item)?;
let mut res = structs::parse(cx, item)?;
res.extend(structs::parse(cx, item)?);
res.extend(packed_check);
Ok(res)
}
Item::Enum(item) => {
let cx = Context::new(args, item.ident.clone(), &item.generics)?;
let mut cx = Context::new(args, item.ident.clone(), &item.generics)?;
let mut res = make_proj_trait(&mut cx)?;

// We don't need to check for '#[repr(packed)]',
// since it does not apply to enums
enums::parse(cx, item)
res.extend(enums::parse(cx, item));
Ok(res)
}
item => Err(error!(item, "may only be used on structs or enums")),
}
}

fn make_proj_trait(cx: &mut Context) -> Result<TokenStream> {
let proj_trait = &cx.projected_trait;
let lifetime = &cx.lifetime;
let proj_ident = &cx.projected;
let proj_generics = proj_generics(&cx.generics, &cx.lifetime);
let proj_ty_generics = proj_generics.split_for_impl().1;

let (orig_generics, _orig_ty_generics, orig_where_clause) = cx.generics.split_for_impl();

Ok(quote! {
trait #proj_trait #orig_generics {
fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #orig_where_clause;
}
})

}

fn ensure_not_packed(item: &ItemStruct) -> Result<TokenStream> {
for meta in item.attrs.iter().filter_map(|attr| attr.parse_meta().ok()) {
if let Meta::List(l) = meta {
Expand Down
7 changes: 4 additions & 3 deletions pin-project-internal/src/pin_project/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ pub(super) fn parse(mut cx: Context, mut item: ItemStruct) -> Result<TokenStream
let impl_drop = cx.impl_drop(&item.generics);
let proj_generics = proj_generics(&item.generics, &cx.lifetime);
let proj_ty_generics = proj_generics.split_for_impl().1;
let proj_trait = &cx.projected_trait;
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let mut proj_items = quote! {
struct #proj_ident #proj_generics #where_clause #proj_fields
};
let proj_method = quote! {
impl #impl_generics #orig_ident #ty_generics #where_clause {
fn project<#lifetime>(self: ::core::pin::Pin<&#lifetime mut Self>) -> #proj_ident #proj_ty_generics {
impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #orig_ident #ty_generics> #where_clause {
fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause {
unsafe {
let this = ::core::pin::Pin::get_unchecked_mut(self);
let this = self.as_mut().get_unchecked_mut();
#proj_ident #proj_init
}
}
Expand Down
4 changes: 4 additions & 0 deletions pin-project-internal/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ pub(crate) fn proj_ident(ident: &Ident) -> Ident {
format_ident!("__{}Projection", ident)
}

pub(crate) fn proj_trait_ident(ident: &Ident) -> Ident {
format_ident!("__{}ProjectionTrait", ident)
}

pub(crate) trait VecExt {
fn find_remove(&mut self, ident: &str) -> Option<Attribute>;
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//! }
//!
//! impl<T, U> Foo<T, U> {
//! fn baz(self: Pin<&mut Self>) {
//! fn baz(mut self: Pin<&mut Self>) {
//! let this = self.project();
//! let _: Pin<&mut T> = this.future; // Pinned reference to the field
//! let _: &mut U = this.field; // Normal reference to the field
Expand Down
46 changes: 41 additions & 5 deletions tests/pin_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@ fn test_pin_project() {

let mut foo = Foo { field1: 1, field2: 2 };

let foo = Pin::new(&mut foo).project();
let mut foo_orig = Pin::new(&mut foo);
let foo = foo_orig.project();

let x: Pin<&mut i32> = foo.field1;
assert_eq!(*x, 1);

let y: &mut i32 = foo.field2;
assert_eq!(*y, 2);

assert_eq!(foo_orig.as_ref().field1, 1);
assert_eq!(foo_orig.as_ref().field2, 2);

let mut foo = Foo { field1: 1, field2: 2 };

let foo = Pin::new(&mut foo).project();
let mut foo = Pin::new(&mut foo);
let foo = foo.project();

let __FooProjection { field1, field2 } = foo;
let _: Pin<&mut i32> = field1;
Expand All @@ -42,7 +47,8 @@ fn test_pin_project() {

let mut bar = Bar(1, 2);

let bar = Pin::new(&mut bar).project();
let mut bar = Pin::new(&mut bar);
let bar = bar.project();

let x: Pin<&mut i32> = bar.0;
assert_eq!(*x, 1);
Expand All @@ -53,6 +59,7 @@ fn test_pin_project() {
// enum

#[pin_project]
#[derive(Eq, PartialEq, Debug)]
enum Baz<A, B, C, D> {
Variant1(#[pin] A, B),
Variant2 {
Expand All @@ -65,7 +72,8 @@ fn test_pin_project() {

let mut baz = Baz::Variant1(1, 2);

let baz = Pin::new(&mut baz).project();
let mut baz_orig = Pin::new(&mut baz);
let baz = baz_orig.project();

match baz {
__BazProjection::Variant1(x, y) => {
Expand All @@ -82,9 +90,12 @@ fn test_pin_project() {
__BazProjection::None => {}
}

assert_eq!(Pin::into_ref(baz_orig).get_ref(), &Baz::Variant1(1, 2));

let mut baz = Baz::Variant2 { field1: 3, field2: 4 };

let mut baz = Pin::new(&mut baz).project();
let mut baz = Pin::new(&mut baz);
let mut baz = baz.project();

match &mut baz {
__BazProjection::Variant1(x, y) => {
Expand All @@ -110,6 +121,31 @@ fn test_pin_project() {
}
}

#[test]
fn enum_project_set() {

#[pin_project]
#[derive(Eq, PartialEq, Debug)]
enum Bar {
Variant1(#[pin] u8),
Variant2(bool)
}

let mut bar = Bar::Variant1(25);
let mut bar_orig = Pin::new(&mut bar);
let bar_proj = bar_orig.project();

match bar_proj {
__BarProjection::Variant1(val) => {
let new_bar = Bar::Variant2(val.as_ref().get_ref() == &25);
bar_orig.set(new_bar);
},
_ => unreachable!()
}

assert_eq!(bar, Bar::Variant2(true));
}

#[test]
fn where_clause_and_associated_type_fields() {
// struct
Expand Down
2 changes: 1 addition & 1 deletion tests/pinned_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Foo<'a> {
}

#[pinned_drop]
fn do_drop(foo: Pin<&mut Foo<'_>>) {
fn do_drop(mut foo: Pin<&mut Foo<'_>>) {
**foo.project().was_dropped = true;
}

Expand Down

0 comments on commit d20f3fe

Please sign in to comment.