Skip to content

Commit

Permalink
Improved implicit borrowing and changed to resolve Askama variables a…
Browse files Browse the repository at this point in the history
…t compile time
  • Loading branch information
vallentin committed Dec 15, 2020
1 parent 5b01e60 commit 5202327
Showing 1 changed file with 139 additions and 49 deletions.
188 changes: 139 additions & 49 deletions askama_shared/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use proc_macro2::Span;

use quote::{quote, ToTokens};

use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::path::PathBuf;
use std::{cmp, hash, mem, str};

Expand All @@ -20,7 +20,7 @@ pub fn generate<S: std::hash::BuildHasher>(
heritage: &Option<Heritage>,
integrations: Integrations,
) -> Result<String, CompileError> {
Generator::new(input, contexts, heritage, integrations, SetChain::new())
Generator::new(input, contexts, heritage, integrations, MapChain::new())
.build(&contexts[&input.path])
}

Expand All @@ -34,7 +34,7 @@ struct Generator<'a, S: std::hash::BuildHasher> {
// What integrations need to be generated
integrations: Integrations,
// Variables accessible directly from the current scope (not redirected to context)
locals: SetChain<'a, &'a str>,
locals: MapChain<'a, &'a str, Option<String>>,
// Suffix whitespace from the previous literal. Will be flushed to the
// output buffer unless suppressed by whitespace suppression on the next
// non-literal.
Expand All @@ -56,7 +56,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
contexts: &'n HashMap<&'n PathBuf, Context<'n>, S>,
heritage: &'n Option<Heritage>,
integrations: Integrations,
locals: SetChain<'n, &'n str>,
locals: MapChain<'n, &'n str, Option<String>>,
) -> Generator<'n, S> {
Generator {
input,
Expand All @@ -73,7 +73,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
}

fn child(&mut self) -> Generator<'_, S> {
let locals = SetChain::with_parent(&self.locals);
let locals = MapChain::with_parent(&self.locals);
Self::new(
self.input,
self.contexts,
Expand Down Expand Up @@ -592,7 +592,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
buf.write("(");
for (i, param) in params.iter().enumerate() {
if let MatchParameter::Name(p) = *param {
self.locals.insert(p);
self.locals.insert(p, None);
}
if i > 0 {
buf.write(", ");
Expand All @@ -606,9 +606,9 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
buf.write("{");
for (i, param) in params.iter().enumerate() {
if let Some(MatchParameter::Name(p)) = param.1 {
self.locals.insert(p);
self.locals.insert(p, None);
} else {
self.locals.insert(param.0);
self.locals.insert(param.0, None);
}

if i > 0 {
Expand Down Expand Up @@ -719,21 +719,58 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
let mut names = Buffer::new(0);
let mut values = Buffer::new(0);
for (i, arg) in def.args.iter().enumerate() {
if i > 0 {
names.write(", ");
values.write(", ");
let expr = args.get(i).ok_or_else(|| {
CompileError::String(format!("macro '{}' takes more than {} arguments", name, i))
})?;

match expr {
// If `expr` is already a form of variable then
// don't reintroduce a new variable. This is
// to avoid moving non-copyable values.
Expr::Var(name) => {
let var = self
.locals
.resolve_var(name)
.map(ToString::to_string)
.unwrap_or_else(|| format!("self.{}", name));
self.locals.insert(arg, Some(var));
}
Expr::Attr(obj, attr) => {
let mut attr_buf = Buffer::new(0);
self.visit_attr(&mut attr_buf, obj, attr)?;

let var = self
.locals
.resolve_var(&attr_buf.buf)
.map(ToString::to_string)
.unwrap_or(attr_buf.buf);
self.locals.insert(arg, Some(var));
continue;
}
// Everything else still needs to become variables,
// to avoid having the same logic be executed
// multiple times, e.g. in the case of macro
// parameters being used multiple times.
_ => {
if i > 0 {
names.write(", ");
values.write(", ");
}
names.write(arg);

values.write("(");
values.write(&self.visit_expr_root(expr)?);
values.write(")");
self.locals.insert(arg, None);
}
}
names.write(arg);
}

values.write("&(");
values.write(&self.visit_expr_root(args.get(i).ok_or_else(|| {
CompileError::String(format!("macro '{}' takes more than {} arguments", name, i))
})?)?);
values.write(")");
self.locals.insert(arg);
debug_assert_eq!(names.buf.is_empty(), values.buf.is_empty());
if !names.buf.is_empty() {
buf.writeln(&format!("let ({}) = ({});", names.buf, values.buf))?;
}

buf.writeln(&format!("let ({}) = ({});", names.buf, values.buf))?;
let mut size_hint = self.handle(own_ctx, &def.nodes, buf, AstLevel::Nested)?;

self.flush_ws(def.ws2);
Expand Down Expand Up @@ -794,13 +831,13 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
buf.write("let ");
match *var {
Target::Name(name) => {
self.locals.insert(name);
self.locals.insert(name, None);
buf.write(name);
}
Target::Tuple(ref targets) => {
buf.write("(");
for name in targets {
self.locals.insert(name);
self.locals.insert(name, None);
buf.write(name);
buf.write(",");
}
Expand All @@ -823,16 +860,16 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {

match *var {
Target::Name(name) => {
if !self.locals.contains(name) {
if !self.locals.contains(&name) {
buf.write("let ");
self.locals.insert(name);
self.locals.insert(name, None);
}
buf.write(name);
}
Target::Tuple(ref targets) => {
buf.write("let (");
for name in targets {
self.locals.insert(name);
self.locals.insert(name, None);
buf.write(name);
buf.write(",");
}
Expand Down Expand Up @@ -1371,11 +1408,18 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
}

fn visit_var(&mut self, buf: &mut Buffer, s: &str) -> DisplayWrap {
if self.locals.contains(s) || s == "self" {
buf.write(s);
} else {
buf.write("self.");
if s == "self" {
buf.write(s);
return DisplayWrap::Unwrapped;
}

match self.locals.get(&s) {
Some(None) => buf.write(s),
Some(Some(mapped_var)) => buf.write(&mapped_var),
None => {
buf.write("self.");
buf.write(s);
}
}
DisplayWrap::Unwrapped
}
Expand All @@ -1387,7 +1431,7 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
args: &[Expr],
) -> Result<DisplayWrap, CompileError> {
buf.write("(");
if self.locals.contains(s) || s == "self" {
if self.locals.contains(&s) || s == "self" {
buf.write(s);
} else {
buf.write("self.");
Expand Down Expand Up @@ -1422,13 +1466,13 @@ impl<'a, S: std::hash::BuildHasher> Generator<'a, S> {
fn visit_target(&mut self, buf: &mut Buffer, target: &'a Target) {
match *target {
Target::Name(name) => {
self.locals.insert(name);
self.locals.insert(name, None);
buf.write(name);
}
Target::Tuple(ref targets) => {
buf.write("(");
for name in targets {
self.locals.insert(name);
self.locals.insert(name, None);
buf.write(name);
buf.write(",");
}
Expand Down Expand Up @@ -1523,51 +1567,73 @@ impl Buffer {
}
}

// type SetChain<'a, T> = MapChain<'a, T, ()>;

#[derive(Debug)]
struct SetChain<'a, T: 'a>
struct MapChain<'a, K: 'a, V: 'a>
where
T: cmp::Eq + hash::Hash,
K: cmp::Eq + hash::Hash,
{
parent: Option<&'a SetChain<'a, T>>,
scopes: Vec<HashSet<T>>,
parent: Option<&'a MapChain<'a, K, V>>,
scopes: Vec<HashMap<K, V>>,
}

impl<'a, T: 'a> SetChain<'a, T>
impl<'a, K: 'a, V: 'a> MapChain<'a, K, V>
where
T: cmp::Eq + hash::Hash,
K: cmp::Eq + hash::Hash,
{
fn new() -> SetChain<'a, T> {
SetChain {
fn new() -> MapChain<'a, K, V> {
MapChain {
parent: None,
scopes: vec![HashSet::new()],
scopes: vec![HashMap::new()],
}
}

fn with_parent<'p>(parent: &'p SetChain<T>) -> SetChain<'p, T> {
SetChain {
fn with_parent<'p>(parent: &'p MapChain<K, V>) -> MapChain<'p, K, V> {
MapChain {
parent: Some(parent),
scopes: vec![HashSet::new()],
scopes: vec![HashMap::new()],
}
}

fn contains(&self, val: T) -> bool {
self.scopes.iter().rev().any(|set| set.contains(&val))
fn contains(&self, key: &K) -> bool {
self.scopes.iter().rev().any(|set| set.contains_key(&key))
|| match self.parent {
Some(set) => set.contains(val),
Some(set) => set.contains(key),
None => false,
}
}

fn get(&self, key: &K) -> Option<&V> {
self.get_skip(key, 0)
}

/// Returns `Some` if `key` exists, while `skip` represents
/// the amount of previous scopes to skip in reverse.
fn get_skip(&self, key: &K, skip: usize) -> Option<&V> {
self.scopes
.iter()
.rev()
.skip(skip)
.filter_map(|set| set.get(&key))
.next()
.or_else(|| match self.parent {
Some(set) => set.get(key),
None => None,
})
}

fn is_current_empty(&self) -> bool {
self.scopes.last().unwrap().is_empty()
}

fn insert(&mut self, val: T) {
self.scopes.last_mut().unwrap().insert(val);
fn insert(&mut self, key: K, val: V) {
let old_val = self.scopes.last_mut().unwrap().insert(key, val);
assert!(old_val.is_none());
}

fn push(&mut self) {
self.scopes.push(HashSet::new());
self.scopes.push(HashMap::new());
}

fn pop(&mut self) {
Expand All @@ -1576,6 +1642,30 @@ where
}
}

impl MapChain<'_, &str, Option<String>> {
/// Given a variable `key` it resolves all the way
/// back to the initial the variable.
fn resolve_var<'a>(&'a self, key: &'a str) -> Option<&'a str> {
let mut key = key;
let mut var = None;
for i in 1.. {
let prev_var = self.get_skip(&key, i);
match prev_var {
Some(Some(prev_var)) => {
key = &prev_var;
var = Some(key);
}
Some(None) => {
var = Some(key);
break;
}
None => break,
}
}
var
}
}

fn median(sizes: &mut [usize]) -> usize {
sizes.sort_unstable();
if sizes.len() % 2 == 1 {
Expand Down

0 comments on commit 5202327

Please sign in to comment.