Skip to content

Commit

Permalink
Inline comprehensions
Browse files Browse the repository at this point in the history
Summary:
This diff adds support for inlining list/dict/set comprehensions where it is considered safe - names introduced by inlined comprehension will not conflict with local names used  in comprehensions or free/implicitly global names used in sibling scopes. It also only inlines comprehensions in functions - inlining for top level statements comes with additional set of challenges and I'm not sure whether adding extra complexity to handle something that is executed once would be worth it.

After inlining comprehension we generate the code to delete locals added by comprehension to avoid adding extra references that are not controlled by user. This works fine for non-exceptional case however in case of exception being raised by the comprehension lifetime of object referenced by comprehension iteration variable will be extended until execution leaves current frame. Another related issue is - if original iterable being used in comprehension yields no values, comprehension iteration variable will stay unbound and `DELETE_FAST` would fail. To handle this we can either:
- relax requirements to `DELETE_FAST` so deleting unbound name would be no-op
 - have a dedicated opcode that would behave as relaxed `DELETE_FAST`
- keep `DELETE_FAST` relaxed (similar to (1)) but change generated code for `del x` to be `LOAD_FAST; POP_TOP; DELETE_FAST` so name binding would still be checked by `LOAD_FAST` (suggested by DinoV)

This diff currently uses option 1 as the simplest one but this could be changed.

Reviewed By: vladima

Differential Revision: D28940584

fbshipit-source-id: b5b7512
  • Loading branch information
DinoV authored and facebook-github-bot committed Nov 12, 2021
1 parent 9e8e3a5 commit 03040db
Show file tree
Hide file tree
Showing 22 changed files with 3,732 additions and 2,765 deletions.
6 changes: 6 additions & 0 deletions Include/pythonrun.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ PyAPI_FUNC(struct symtable *) _Py_SymtableStringObjectFlags(
PyObject *filename,
int start,
PyCompilerFlags *flags);
PyAPI_FUNC(struct symtable *) _Py_SymtableStringObjectFlagsOptFlags(
const char *str,
PyObject *filename,
int start,
PyCompilerFlags *flags,
int inline_comprehensions);
#endif

PyAPI_FUNC(void) PyErr_Print(void);
Expand Down
7 changes: 7 additions & 0 deletions Include/symtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct symtable {
the symbol table */
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
int st_inline_comprehensions;
};

typedef struct _symtable_entry {
Expand Down Expand Up @@ -64,6 +65,7 @@ typedef struct _symtable_entry {
int ste_col_offset; /* offset of first line of block */
int ste_opt_lineno; /* lineno of last exec or import * */
int ste_opt_col_offset; /* offset of last exec or import * */
unsigned int ste_inlined_comprehension; /* true is comprehension is inlined and symbols were already merged in parent scope */
struct symtable *ste_table;
} PySTEntryObject;

Expand All @@ -81,6 +83,11 @@ PyAPI_FUNC(struct symtable *) PySymtable_BuildObject(
mod_ty mod,
PyObject *filename,
PyFutureFeatures *future);
PyAPI_FUNC(struct symtable *) _PySymtable_BuildObjectOptFlags(
mod_ty mod,
PyObject *filename,
PyFutureFeatures *future,
int inline_comprehensions);
PyAPI_FUNC(PySTEntryObject *) PySymtable_Lookup(struct symtable *, void *);

PyAPI_FUNC(void) PySymtable_Free(struct symtable *);
Expand Down
99 changes: 87 additions & 12 deletions Lib/compiler/pycodegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class CodeGenerator(ASTVisitor):
class_name = None # provide default for instance variable
future_flags = 0
flow_graph = pyassem.PyFlowGraph
_SymbolVisitor = symbols.SymbolVisitor

def __init__(
self,
Expand Down Expand Up @@ -912,7 +913,7 @@ def get_qual_prefix(self, gen):
while not isinstance(parent, symbols.ModuleScope):
# Only real functions use "<locals>", nested scopes like
# comprehensions don't.
if type(parent) in (symbols.FunctionScope, symbols.LambdaScope):
if parent.is_function_scope:
prefix = parent.name + ".<locals>." + prefix
else:
prefix = parent.name + "." + prefix
Expand Down Expand Up @@ -961,7 +962,9 @@ def compile_comprehension(
if opcode:
gen.emit(opcode, oparg)

gen.compile_comprehension_generator(node.generators, 0, elt, val, type(node))
gen.compile_comprehension_generator(
node.generators, 0, elt, val, type(node), True
)

if not isinstance(node, ast.GeneratorExp):
gen.emit("RETURN_VALUE")
Expand Down Expand Up @@ -1001,19 +1004,27 @@ def visitDictComp(self, node):
node, sys.intern("<dictcomp>"), node.key, node.value, "BUILD_MAP"
)

def compile_comprehension_generator(self, generators, gen_index, elt, val, type):
def compile_comprehension_generator(
self, generators, gen_index, elt, val, type, outermost_gen_is_param
):
if generators[gen_index].is_async:
self.compile_async_comprehension(generators, gen_index, elt, val, type)
self.compile_async_comprehension(
generators, gen_index, elt, val, type, outermost_gen_is_param
)
else:
self.compile_sync_comprehension(generators, gen_index, elt, val, type)
self.compile_sync_comprehension(
generators, gen_index, elt, val, type, outermost_gen_is_param
)

def compile_async_comprehension(self, generators, gen_index, elt, val, type):
def compile_async_comprehension(
self, generators, gen_index, elt, val, type, outermost_gen_is_param
):
start = self.newBlock("start")
except_ = self.newBlock("except")
if_cleanup = self.newBlock("if_cleanup")

gen = generators[gen_index]
if gen_index == 0:
if gen_index == 0 and outermost_gen_is_param:
self.loadName(".0")
else:
self.visit(gen.iter)
Expand All @@ -1033,7 +1044,9 @@ def compile_async_comprehension(self, generators, gen_index, elt, val, type):

gen_index += 1
if gen_index < len(generators):
self.compile_comprehension_generator(generators, gen_index, elt, val, type)
self.compile_comprehension_generator(
generators, gen_index, elt, val, type, False
)
elif type is ast.GeneratorExp:
self.visit(elt)
self.emit("YIELD_VALUE")
Expand All @@ -1056,14 +1069,16 @@ def compile_async_comprehension(self, generators, gen_index, elt, val, type):
self.nextBlock(except_)
self.emit("END_ASYNC_FOR")

def compile_sync_comprehension(self, generators, gen_index, elt, val, type):
def compile_sync_comprehension(
self, generators, gen_index, elt, val, type, outermost_gen_is_param
):
start = self.newBlock("start")
skip = self.newBlock("skip")
if_cleanup = self.newBlock("if_cleanup")
anchor = self.newBlock("anchor")

gen = generators[gen_index]
if gen_index == 0:
if gen_index == 0 and outermost_gen_is_param:
self.loadName(".0")
else:
self.visit(gen.iter)
Expand All @@ -1080,7 +1095,9 @@ def compile_sync_comprehension(self, generators, gen_index, elt, val, type):

gen_index += 1
if gen_index < len(generators):
self.compile_comprehension_generator(generators, gen_index, elt, val, type)
self.compile_comprehension_generator(
generators, gen_index, elt, val, type, False
)
else:
if type is ast.GeneratorExp:
self.visit(elt)
Expand Down Expand Up @@ -2319,7 +2336,7 @@ def make_code_gen(
):
if ast_optimizer_enabled:
tree = cls.optimize_tree(optimize, tree)
s = symbols.SymbolVisitor()
s = cls._SymbolVisitor()
walk(tree, s)

graph = cls.flow_graph(
Expand Down Expand Up @@ -2361,6 +2378,7 @@ def __init__(self, kind, block, exit):

class CinderCodeGenerator(CodeGenerator):
flow_graph = pyassem.PyFlowGraphCinder
_SymbolVisitor = symbols.CinderSymbolVisitor

def set_qual_name(self, qualname):
self._qual_name = qualname
Expand Down Expand Up @@ -2463,6 +2481,63 @@ def findFutures(self, node):
future_flags |= consts.CO_FUTURE_LAZY_IMPORTS
return future_flags

def compile_comprehension(self, node, name, elt, val, opcode, oparg=0):
self.update_lineno(node)
# fetch the scope that correspond to comprehension
scope = self.scopes[node]
if scope.inlined:
# for inlined comprehension process with current generator
gen = self
else:
gen = self.make_func_codegen(
node, self.conjure_arguments([ast.arg(".0", None)]), name, node.lineno
)

if opcode:
gen.emit(opcode, oparg)

gen.compile_comprehension_generator(
node.generators, 0, elt, val, type(node), not scope.inlined
)

if scope.inlined:
# collect list of defs that were introduced by comprehension
# note that we need to exclude:
# - .0 parameter since it is used
# - non-local names (typically named expressions), they are
# defined in enclosing scope and thus should not be deleted
to_delete = [
v
for v in scope.defs
if v != ".0" and v not in scope.nonlocals and v not in scope.cells
]
# sort names to have deterministic deletion order
to_delete.sort()
for v in to_delete:
self.delName(v)
return

if not isinstance(node, ast.GeneratorExp):
gen.emit("RETURN_VALUE")

gen.finishFunction()

self._makeClosure(gen, 0)

# precomputation of outmost iterable
self.visit(node.generators[0].iter)
if node.generators[0].is_async:
self.emit("GET_AITER")
else:
self.emit("GET_ITER")
self.emit("CALL_FUNCTION", 1)

if gen.scope.coroutine and type(node) is not ast.GeneratorExp:
self.emit("GET_AWAITABLE")
self.emit("LOAD_CONST", None)
self.emit("YIELD_FROM")


def get_default_generator():

if "cinder" in sys.version:
Expand Down
2 changes: 1 addition & 1 deletion Lib/compiler/static/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def _bind(
if name not in self.modules:
tree = self.add_module(name, filename, tree, optimize)
# Analyze variable scopes
s = SymbolVisitor()
s = self.code_generator._SymbolVisitor()
s.visit(tree)

# Analyze the types of objects within local scopes
Expand Down
2 changes: 1 addition & 1 deletion Lib/compiler/static/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
from ..optimizer import AstOptimizer
from ..pyassem import Block
from ..pycodegen import FOR_LOOP, CodeGenerator
from ..symbols import SymbolVisitor
from ..symbols import SymbolVisitor, CinderSymbolVisitor
from ..symbols import Scope, ModuleScope
from ..unparse import to_expr
from ..visitor import ASTVisitor
Expand Down
2 changes: 1 addition & 1 deletion Lib/compiler/strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def make_code_gen(
) -> StrictCodeGenerator:
if ast_optimizer_enabled:
tree = cls.optimize_tree(optimize, tree)
s = symbols.SymbolVisitor()
s = cls._SymbolVisitor()
walk(tree, s)

graph = cls.flow_graph(
Expand Down
Loading

0 comments on commit 03040db

Please sign in to comment.