Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add derives for Encode and Decode #71

Merged
merged 7 commits into from
Jan 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,9 @@ required-features = [ "mysql" ]
name = "mysql-types-chrono"
required-features = [ "mysql", "chrono", "macros" ]

[[test]]
name = "derives"
required-features = [ "macros" ]

[profile.release]
lto = true
88 changes: 88 additions & 0 deletions sqlx-macros/src/derives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use quote::quote;
use syn::{parse_quote, Data, DataStruct, DeriveInput, Fields, FieldsUnnamed};

pub(crate) fn expand_derive_encode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
let ident = &input.ident;
let ty = &unnamed.first().unwrap().ty;

// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();

// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::encode::Encode<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();

Ok(quote!(
impl #impl_generics sqlx::encode::Encode<DB> for #ident #ty_generics #where_clause {
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
sqlx::encode::Encode::encode(&self.0, buf)
}
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_nullable(&self.0, buf)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&self.0)
}
}
))
}
_ => Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
)),
}
}

pub(crate) fn expand_derive_decode(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
..
}) if unnamed.len() == 1 => {
let ident = &input.ident;
let ty = &unnamed.first().unwrap().ty;

// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();

// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();

Ok(quote!(
impl #impl_generics sqlx::decode::Decode<DB> for #ident #ty_generics #where_clause {
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode(raw).map(Self)
}
fn decode_null() -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode_null().map(Self)
}
fn decode_nullable(raw: std::option::Option<&[u8]>) -> std::result::Result<Self, sqlx::decode::DecodeError> {
<#ty as sqlx::decode::Decode<DB>>::decode_nullable(raw).map(Self)
}
}
))
}
_ => Err(syn::Error::new_spanned(
input,
"expected a tuple struct with a single field",
)),
}
}
20 changes: 20 additions & 0 deletions sqlx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type Result<T> = std::result::Result<T, Error>;

mod database;

mod derives;

mod query_macros;

use query_macros::*;
Expand Down Expand Up @@ -136,3 +138,21 @@ pub fn query_as(input: TokenStream) -> TokenStream {
pub fn query_file_as(input: TokenStream) -> TokenStream {
async_macro!(db, input: QueryAsMacroInput => expand_query_file_as(input, db))
}

#[proc_macro_derive(Encode)]
pub fn derive_encode(tokenstream: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
match derives::expand_derive_encode(input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}

#[proc_macro_derive(Decode)]
pub fn derive_decode(tokenstream: TokenStream) -> TokenStream {
let input = syn::parse_macro_input!(tokenstream as syn::DeriveInput);
match derives::expand_derive_decode(input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
6 changes: 6 additions & 0 deletions src/decode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! Types and traits for decoding values from the database.

pub use sqlx_core::decode::*;

#[cfg(feature = "macros")]
pub use sqlx_macros::Decode;
6 changes: 6 additions & 0 deletions src/encode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! Types and traits for encoding values to the database.

pub use sqlx_core::encode::*;

#[cfg(feature = "macros")]
pub use sqlx_macros::Encode;
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ compile_error!("one of 'runtime-async-std' or 'runtime-tokio' features must be e
compile_error!("only one of 'runtime-async-std' or 'runtime-tokio' features must be enabled");

// Modules
pub use sqlx_core::{arguments, decode, describe, encode, error, pool, row, types};
pub use sqlx_core::{arguments, describe, error, pool, row, types};

// Types
pub use sqlx_core::{
Expand Down Expand Up @@ -42,3 +42,7 @@ pub mod ty_cons;
#[cfg(feature = "macros")]
#[doc(hidden)]
pub mod result_ext;

pub mod encode;

pub mod decode;
60 changes: 60 additions & 0 deletions tests/derives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use sqlx::decode::Decode;
use sqlx::encode::Encode;

#[derive(PartialEq, Debug, Encode, Decode)]
struct Foo(i32);

#[test]
#[cfg(feature = "mysql")]
fn encode_mysql() {
encode_with_db::<sqlx::MySql>();
}

#[test]
#[cfg(feature = "postgres")]
fn encode_postgres() {
encode_with_db::<sqlx::Postgres>();
}

#[allow(dead_code)]
fn encode_with_db<DB: sqlx::Database>()
where
Foo: Encode<DB>,
i32: Encode<DB>,
{
let example = Foo(0x1122_3344);

let mut encoded = Vec::new();
let mut encoded_orig = Vec::new();

Encode::<DB>::encode(&example, &mut encoded);
Encode::<DB>::encode(&example.0, &mut encoded_orig);

assert_eq!(encoded, encoded_orig);
}

#[test]
#[cfg(feature = "mysql")]
fn decode_mysql() {
decode_with_db::<sqlx::MySql>();
}

#[test]
#[cfg(feature = "postgres")]
fn decode_postgres() {
decode_with_db::<sqlx::Postgres>();
}

#[allow(dead_code)]
fn decode_with_db<DB: sqlx::Database>()
where
Foo: Decode<DB> + Encode<DB>,
{
let example = Foo(0x1122_3344);

let mut encoded = Vec::new();
Encode::<DB>::encode(&example, &mut encoded);

let decoded = Foo::decode(&encoded).unwrap();
assert_eq!(example, decoded);
}