diff --git a/askama_shared/src/generator.rs b/askama_shared/src/generator.rs index 6a185c12a..ea8c09dda 100644 --- a/askama_shared/src/generator.rs +++ b/askama_shared/src/generator.rs @@ -10,7 +10,7 @@ use proc_macro2::Span; use quote::{quote, ToTokens}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::path::PathBuf; use std::{cmp, hash, mem, str}; @@ -20,7 +20,7 @@ pub fn generate( heritage: &Option, integrations: Integrations, ) -> Result { - Generator::new(input, contexts, heritage, integrations, SetChain::new()) + Generator::new(input, contexts, heritage, integrations, MapChain::new()) .build(&contexts[&input.path]) } @@ -34,7 +34,7 @@ struct Generator<'a, S: std::hash::BuildHasher> { // What integrations need to be generated integrations: Integrations, // Variables accessible directly from the current scope (not redirected to context) - locals: SetChain<'a, &'a str>, + locals: MapChain<'a, &'a str, Option>, // Suffix whitespace from the previous literal. Will be flushed to the // output buffer unless suppressed by whitespace suppression on the next // non-literal. @@ -56,7 +56,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { contexts: &'n HashMap<&'n PathBuf, Context<'n>, S>, heritage: &'n Option, integrations: Integrations, - locals: SetChain<'n, &'n str>, + locals: MapChain<'n, &'n str, Option>, ) -> Generator<'n, S> { Generator { input, @@ -73,7 +73,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { } fn child(&mut self) -> Generator<'_, S> { - let locals = SetChain::with_parent(&self.locals); + let locals = MapChain::with_parent(&self.locals); Self::new( self.input, self.contexts, @@ -592,7 +592,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { buf.write("("); for (i, param) in params.iter().enumerate() { if let MatchParameter::Name(p) = *param { - self.locals.insert(p); + self.locals.insert(p, None); } if i > 0 { buf.write(", "); @@ -606,9 +606,9 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { buf.write("{"); for (i, param) in params.iter().enumerate() { if let Some(MatchParameter::Name(p)) = param.1 { - self.locals.insert(p); + self.locals.insert(p, None); } else { - self.locals.insert(param.0); + self.locals.insert(param.0, None); } if i > 0 { @@ -719,21 +719,58 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { let mut names = Buffer::new(0); let mut values = Buffer::new(0); for (i, arg) in def.args.iter().enumerate() { - if i > 0 { - names.write(", "); - values.write(", "); + let expr = args.get(i).ok_or_else(|| { + CompileError::String(format!("macro '{}' takes more than {} arguments", name, i)) + })?; + + match expr { + // If `expr` is already a form of variable then + // don't reintroduce a new variable. This is + // to avoid moving non-copyable values. + Expr::Var(name) => { + let var = self + .locals + .resolve_var(name) + .map(ToString::to_string) + .unwrap_or_else(|| format!("self.{}", name)); + self.locals.insert(arg, Some(var)); + } + Expr::Attr(obj, attr) => { + let mut attr_buf = Buffer::new(0); + self.visit_attr(&mut attr_buf, obj, attr)?; + + let var = self + .locals + .resolve_var(&attr_buf.buf) + .map(ToString::to_string) + .unwrap_or(attr_buf.buf); + self.locals.insert(arg, Some(var)); + continue; + } + // Everything else still needs to become variables, + // to avoid having the same logic be executed + // multiple times, e.g. in the case of macro + // parameters being used multiple times. + _ => { + if i > 0 { + names.write(", "); + values.write(", "); + } + names.write(arg); + + values.write("("); + values.write(&self.visit_expr_root(expr)?); + values.write(")"); + self.locals.insert(arg, None); + } } - names.write(arg); + } - values.write("&("); - values.write(&self.visit_expr_root(args.get(i).ok_or_else(|| { - CompileError::String(format!("macro '{}' takes more than {} arguments", name, i)) - })?)?); - values.write(")"); - self.locals.insert(arg); + debug_assert_eq!(names.buf.is_empty(), values.buf.is_empty()); + if !names.buf.is_empty() { + buf.writeln(&format!("let ({}) = ({});", names.buf, values.buf))?; } - buf.writeln(&format!("let ({}) = ({});", names.buf, values.buf))?; let mut size_hint = self.handle(own_ctx, &def.nodes, buf, AstLevel::Nested)?; self.flush_ws(def.ws2); @@ -794,13 +831,13 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { buf.write("let "); match *var { Target::Name(name) => { - self.locals.insert(name); + self.locals.insert(name, None); buf.write(name); } Target::Tuple(ref targets) => { buf.write("("); for name in targets { - self.locals.insert(name); + self.locals.insert(name, None); buf.write(name); buf.write(","); } @@ -823,16 +860,16 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { match *var { Target::Name(name) => { - if !self.locals.contains(name) { + if !self.locals.contains(&name) { buf.write("let "); - self.locals.insert(name); + self.locals.insert(name, None); } buf.write(name); } Target::Tuple(ref targets) => { buf.write("let ("); for name in targets { - self.locals.insert(name); + self.locals.insert(name, None); buf.write(name); buf.write(","); } @@ -1371,11 +1408,18 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { } fn visit_var(&mut self, buf: &mut Buffer, s: &str) -> DisplayWrap { - if self.locals.contains(s) || s == "self" { - buf.write(s); - } else { - buf.write("self."); + if s == "self" { buf.write(s); + return DisplayWrap::Unwrapped; + } + + match self.locals.get(&s) { + Some(None) => buf.write(s), + Some(Some(mapped_var)) => buf.write(&mapped_var), + None => { + buf.write("self."); + buf.write(s); + } } DisplayWrap::Unwrapped } @@ -1387,7 +1431,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { args: &[Expr], ) -> Result { buf.write("("); - if self.locals.contains(s) || s == "self" { + if self.locals.contains(&s) || s == "self" { buf.write(s); } else { buf.write("self."); @@ -1422,13 +1466,13 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> { fn visit_target(&mut self, buf: &mut Buffer, target: &'a Target) { match *target { Target::Name(name) => { - self.locals.insert(name); + self.locals.insert(name, None); buf.write(name); } Target::Tuple(ref targets) => { buf.write("("); for name in targets { - self.locals.insert(name); + self.locals.insert(name, None); buf.write(name); buf.write(","); } @@ -1523,51 +1567,73 @@ impl Buffer { } } +// type SetChain<'a, T> = MapChain<'a, T, ()>; + #[derive(Debug)] -struct SetChain<'a, T: 'a> +struct MapChain<'a, K: 'a, V: 'a> where - T: cmp::Eq + hash::Hash, + K: cmp::Eq + hash::Hash, { - parent: Option<&'a SetChain<'a, T>>, - scopes: Vec>, + parent: Option<&'a MapChain<'a, K, V>>, + scopes: Vec>, } -impl<'a, T: 'a> SetChain<'a, T> +impl<'a, K: 'a, V: 'a> MapChain<'a, K, V> where - T: cmp::Eq + hash::Hash, + K: cmp::Eq + hash::Hash, { - fn new() -> SetChain<'a, T> { - SetChain { + fn new() -> MapChain<'a, K, V> { + MapChain { parent: None, - scopes: vec![HashSet::new()], + scopes: vec![HashMap::new()], } } - fn with_parent<'p>(parent: &'p SetChain) -> SetChain<'p, T> { - SetChain { + fn with_parent<'p>(parent: &'p MapChain) -> MapChain<'p, K, V> { + MapChain { parent: Some(parent), - scopes: vec![HashSet::new()], + scopes: vec![HashMap::new()], } } - fn contains(&self, val: T) -> bool { - self.scopes.iter().rev().any(|set| set.contains(&val)) + fn contains(&self, key: &K) -> bool { + self.scopes.iter().rev().any(|set| set.contains_key(&key)) || match self.parent { - Some(set) => set.contains(val), + Some(set) => set.contains(key), None => false, } } + fn get(&self, key: &K) -> Option<&V> { + self.get_skip(key, 0) + } + + /// Returns `Some` if `key` exists, while `skip` represents + /// the amount of previous scopes to skip in reverse. + fn get_skip(&self, key: &K, skip: usize) -> Option<&V> { + self.scopes + .iter() + .rev() + .skip(skip) + .filter_map(|set| set.get(&key)) + .next() + .or_else(|| match self.parent { + Some(set) => set.get(key), + None => None, + }) + } + fn is_current_empty(&self) -> bool { self.scopes.last().unwrap().is_empty() } - fn insert(&mut self, val: T) { - self.scopes.last_mut().unwrap().insert(val); + fn insert(&mut self, key: K, val: V) { + let old_val = self.scopes.last_mut().unwrap().insert(key, val); + assert!(old_val.is_none()); } fn push(&mut self) { - self.scopes.push(HashSet::new()); + self.scopes.push(HashMap::new()); } fn pop(&mut self) { @@ -1576,6 +1642,30 @@ where } } +impl MapChain<'_, &str, Option> { + /// Given a variable `key` it resolves all the way + /// back to the initial the variable. + fn resolve_var<'a>(&'a self, key: &'a str) -> Option<&'a str> { + let mut key = key; + let mut var = None; + for i in 1.. { + let prev_var = self.get_skip(&key, i); + match prev_var { + Some(Some(prev_var)) => { + key = &prev_var; + var = Some(key); + } + Some(None) => { + var = Some(key); + break; + } + None => break, + } + } + var + } +} + fn median(sizes: &mut [usize]) -> usize { sizes.sort_unstable(); if sizes.len() % 2 == 1 {