Skip to content

Commit

Permalink
Base implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
la10736 committed Jan 1, 2025
1 parent 7792a9b commit d77d7c6
Show file tree
Hide file tree
Showing 9 changed files with 283 additions and 39 deletions.
16 changes: 16 additions & 0 deletions rstest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1469,3 +1469,19 @@ pub use rstest_macros::fixture;
/// ```
///
pub use rstest_macros::rstest;

pub struct Context {
pub name: &'static str,
pub description: Option<&'static str>,
pub case: Option<usize>,
}

impl Context {
pub fn new(name: &'static str, description: Option<&'static str>, case: Option<usize>) -> Self {
Self {
name,
description,
case,
}
}
}
16 changes: 16 additions & 0 deletions rstest/tests/resources/rstest/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use rstest::*;

#[rstest]
#[case::description(42)]
fn with_case(#[context] ctx: Context, #[case] _c: u32) {
assert_eq!("with_case", ctx.name);
assert_eq!(Some("description"), ctx.description);
assert_eq!(Some(0), ctx.case);
}

#[rstest]
fn without_case(#[context] ctx: Context) {
assert_eq!("without_case", ctx.name);
assert_eq!(None, ctx.description);
assert_eq!(None, ctx.case);
}
10 changes: 10 additions & 0 deletions rstest/tests/rstest/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,16 @@ fn no_std() {
.assert(output);
}

#[test]
fn context() {
let (output, _) = run_test("context.rs");

TestResults::new()
.ok("with_case::case_1_description")
.ok("without_case")
.assert(output);
}

mod async_timeout_feature {
use super::*;

Expand Down
16 changes: 15 additions & 1 deletion rstest_macros/src/parse/arguments.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use quote::format_ident;
use syn::{FnArg, Ident, Pat};
Expand Down Expand Up @@ -98,6 +98,7 @@ pub(crate) struct ArgumentsInfo {
args: Args,
is_global_await: bool,
once: Option<syn::Attribute>,
contexts: HashSet<Pat>,
}

impl ArgumentsInfo {
Expand Down Expand Up @@ -235,6 +236,19 @@ impl ArgumentsInfo {
fn_arg
})
}

#[allow(dead_code)]
pub(crate) fn add_context(&mut self, pat: Pat) {
self.contexts.insert(pat);
}

pub(crate) fn set_contexts(&mut self, contexts: impl Iterator<Item = Pat>) {
contexts.for_each(|c| self.add_context(c))
}

pub(crate) fn contexts(&self) -> impl Iterator<Item = &Pat> + '_ {
self.contexts.iter()
}
}

#[cfg(test)]
Expand Down
60 changes: 60 additions & 0 deletions rstest_macros/src/parse/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use syn::{visit_mut::VisitMut, ItemFn, Pat};

use crate::error::ErrorsVec;

use super::just_once::JustOnceFnArgAttributeExtractor;

pub(crate) fn extract_context(item_fn: &mut ItemFn) -> Result<Vec<Pat>, ErrorsVec> {
let mut extractor = JustOnceFnArgAttributeExtractor::from("context");
extractor.visit_item_fn_mut(item_fn);
extractor.take()
}

#[cfg(test)]
mod should {
use super::*;
use crate::test::{assert_eq, *};
use rstest_test::assert_in;

#[rstest]
#[case("fn simple(a: u32) {}")]
#[case("fn more(a: u32, b: &str) {}")]
#[case("fn gen<S: AsRef<str>>(a: u32, b: S) {}")]
#[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")]
fn not_change_anything_if_no_ignore_attribute_found(#[case] item_fn: &str) {
let mut item_fn: ItemFn = item_fn.ast();
let orig = item_fn.clone();

let by_refs = extract_context(&mut item_fn).unwrap();

assert_eq!(orig, item_fn);
assert!(by_refs.is_empty());
}

#[rstest]
#[case::simple("fn f(#[context] a: u32) {}", "fn f(a: u32) {}", &["a"])]
#[case::more_than_one(
"fn f(#[context] a: u32, #[context] b: String, #[context] c: std::collection::HashMap<usize, String>) {}",
r#"fn f(a: u32,
b: String,
c: std::collection::HashMap<usize, String>) {}"#,
&["a", "b", "c"])]
fn extract(#[case] item_fn: &str, #[case] expected: &str, #[case] expected_refs: &[&str]) {
let mut item_fn: ItemFn = item_fn.ast();
let expected: ItemFn = expected.ast();

let by_refs = extract_context(&mut item_fn).unwrap();

assert_eq!(expected, item_fn);
assert_eq!(by_refs, to_pats!(expected_refs));
}

#[test]
fn raise_error() {
let mut item_fn: ItemFn = "fn f(#[context] #[context] a: u32) {}".ast();

let err = extract_context(&mut item_fn).unwrap_err();

assert_in!(format!("{:?}", err), "more than once");
}
}
1 change: 1 addition & 0 deletions rstest_macros/src/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub(crate) mod macros;

pub(crate) mod arguments;
pub(crate) mod by_ref;
pub(crate) mod context;
pub(crate) mod expressions;
pub(crate) mod fixture;
pub(crate) mod future;
Expand Down
34 changes: 30 additions & 4 deletions rstest_macros/src/parse/rstest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use self::files::{extract_files, ValueListFromFiles};
use super::{
arguments::ArgumentsInfo,
by_ref::extract_by_ref,
check_timeout_attrs, extract_case_args, extract_cases, extract_excluded_trace,
extract_fixtures, extract_value_list,
check_timeout_attrs,
context::extract_context,
extract_case_args, extract_cases, extract_excluded_trace, extract_fixtures, extract_value_list,
future::{extract_futures, extract_global_awt},
ignore::extract_ignores,
parse_vector_trailing_till_double_comma,
Expand Down Expand Up @@ -49,20 +50,24 @@ impl Parse for RsTestInfo {

impl ExtendWithFunctionAttrs for RsTestInfo {
fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> {
let composed_tuple!(_inner, excluded, _timeout, futures, global_awt, by_refs, ignores) = merge_errors!(
let composed_tuple!(
_inner, excluded, _timeout, futures, global_awt, by_refs, ignores, contexts
) = merge_errors!(
self.data.extend_with_function_attrs(item_fn),
extract_excluded_trace(item_fn),
check_timeout_attrs(item_fn),
extract_futures(item_fn),
extract_global_awt(item_fn),
extract_by_ref(item_fn),
extract_ignores(item_fn)
extract_ignores(item_fn),
extract_context(item_fn)
)?;
self.attributes.add_notraces(excluded);
self.arguments.set_global_await(global_awt);
self.arguments.set_futures(futures.into_iter());
self.arguments.set_by_refs(by_refs.into_iter());
self.arguments.set_ignores(ignores.into_iter());
self.arguments.set_contexts(contexts.into_iter());
self.arguments
.register_inner_destructored_idents_names(item_fn);
Ok(())
Expand Down Expand Up @@ -379,6 +384,8 @@ mod test {
}

mod no_cases {
use std::collections::HashSet;

use super::{assert_eq, *};

#[test]
Expand Down Expand Up @@ -563,6 +570,25 @@ mod test {
assert!(info.arguments.is_future(&pat("a")));
assert!(!info.arguments.is_future(&pat("b")));
}

#[rstest]
fn extract_context() {
let mut item_fn =
"fn f(#[context] c: Context, #[context] other: Context, more: u32) {}".ast();
let expected = "fn f(c: Context, other: Context, more: u32) {}".ast();

let mut info = RsTestInfo::default();

info.extend_with_function_attrs(&mut item_fn).unwrap();

assert_eq!(item_fn, expected);
assert_eq!(
info.arguments.contexts().cloned().collect::<HashSet<_>>(),
vec![pat("c"), pat("other")]
.into_iter()
.collect::<HashSet<_>>()
);
}
}

mod parametrize_cases {
Expand Down
Loading

0 comments on commit d77d7c6

Please sign in to comment.