Skip to content
This repository has been archived by the owner on Jan 27, 2022. It is now read-only.

Commit

Permalink
chore(Py/WIP): Refactor parser to be recursive and handle more types
Browse files Browse the repository at this point in the history
  • Loading branch information
beneboy committed Sep 2, 2019
1 parent 398e658 commit 0199b1e
Showing 1 changed file with 208 additions and 86 deletions.
294 changes: 208 additions & 86 deletions py/stencila/schema/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class DataFrame:
class CodeChunkParseResult(typing.NamedTuple):
chunk_ast: typing.Optional[ast.Module] = None
imports: typing.List[typing.Union[str, SoftwareSourceCode]] = []
assigns: typing.List[typing.Union[Variable]] = []
assigns: typing.List[str] = []
declares: typing.List[typing.Union[Function, Variable]] = []
alters: typing.List[str] = []
uses: typing.List[str] = []
Expand Down Expand Up @@ -106,7 +106,10 @@ class DocumentCompilationResult:
imports: typing.List[str] = []


def annotation_name_to_schema(name: str) -> typing.Optional[SchemaTypes]:
def annotation_name_to_schema(name: typing.Optional[str]) -> typing.Optional[SchemaTypes]:
if not name:
return None

if name == 'bool':
return BooleanSchema()
elif name == 'str':
Expand Down Expand Up @@ -174,105 +177,223 @@ def set_code_error(code: typing.Union[CodeChunk, CodeExpression], e: typing.Unio
code.errors.append(e)


def parse_code_chunk(chunk: CodeChunk) -> CodeChunkParseResult:
try:
chunk_ast = ast.parse(chunk.text)
except Exception as e:
return CodeChunkParseResult(None, error=exception_to_code_error(e))
class CodeChunkParser:
imports: typing.List[str]
declares: typing.List[typing.Union[Variable, Function]]
assigns: typing.List[str]
alters: typing.List[str]
uses: typing.List[str]
reads: typing.List[str]

imports: typing.List[str] = []
assigns: typing.Set[Variable] = set()
declares: typing.Set[typing.Union[Function, Variable]] = set()
alters: typing.Set[str] = set()
uses: typing.Set[str] = set()
reads: typing.Set[str] = set()
seen_vars: typing.Set[str] = set()

# If this is True, then there should be a call to 'open' somewhere in the code, which means the parser should
# try to find it. This is a basic check so there might not be one (like if the code did , but if 'open(' is NOT in
# the string then there definitely ISN'T one
search_for_open = 'open(' in chunk.text

for statement in chunk_ast.body:
if isinstance(statement, ast.ImportFrom):
if statement.module not in imports:
imports.append(statement.module)
elif isinstance(statement, ast.Import):
for module_name in statement.names:
if module_name.name not in imports:
imports.append(module_name.name)
seen_vars: typing.List[str]

def reset(self) -> None:
self.imports = []
self.declares = []
self.assigns = []
self.alters = []
self.uses = []
self.reads = []

self.seen_vars = []

def add_variable(self, name: str, type_annotation: typing.Optional[str]) -> None:
if name in self.seen_vars:
return
v = Variable(name)
v.schema = annotation_name_to_schema(type_annotation)
self.seen_vars.append(name)
self.declares.append(v)

def add_name(self, name: str, target: typing.List) -> None:
if name not in self.seen_vars and name not in target:
self.seen_vars.append(name)
target.append(name)

def parse(self, chunk: CodeChunk) -> CodeChunkParseResult:
self.reset()

try:
chunk_ast = ast.parse(chunk.text)
except Exception as e:
return CodeChunkParseResult(None, error=exception_to_code_error(e))

# If this is True, then there should be a call to 'open' somewhere in the code, which means the parser should
# try to find it. This is a basic check so there might not be one (like if the code did , but if 'open(' is NOT in
# the string then there definitely ISN'T one
search_for_open = 'open(' in chunk.text

for statement in chunk_ast.body:
self.parse_statement(statement)

if search_for_open:
self.find_file_reads(chunk_ast)

return CodeChunkParseResult(chunk_ast, self.imports, self.assigns, self.declares, self.alters, self.uses,
self.reads)

def parse_statement(self, statement: typing.Union[ast.stmt, typing.List[ast.stmt]]) -> None:
if isinstance(statement, list):
for sub_statement in statement:
self.parse_statement(sub_statement)
elif isinstance(statement, ast.ImportFrom):
self.parse_import(statement)
elif isinstance(statement, (ast.Assign, ast.AnnAssign)):
self.parse_assigns(statement)
elif isinstance(statement, ast.BinOp):
self.parse_bin_op(statement)
elif isinstance(statement, ast.Call):
self.parse_call(statement)
elif isinstance(statement, ast.FunctionDef):
f = Function(statement.name)
f.parameters = []
self.parse_function_def(statement)
elif isinstance(statement, ast.Dict):
self.parse_dict(statement)
elif isinstance(statement, ast.List):
self.parse_statement(statement.elts)
elif isinstance(statement, ast.Name):
self.add_name(statement.id, self.uses)
elif isinstance(statement, ast.Expr):
self.parse_statement(statement.value)
elif isinstance(statement, ast.AugAssign):
self.parse_aug_assign(statement)
elif isinstance(statement, ast.If):
self.parse_if(statement)
elif isinstance(statement, ast.Compare):
self.parse_compare(statement)
elif isinstance(statement, ast.For):
self.parse_for(statement)
elif isinstance(statement, (ast.ClassDef, ast.Num, ast.Str)):
pass
else:
raise TypeError('Unrecognized statement: {}'.format(statement))

for i, arg in enumerate(statement.args.args):
p = Parameter(arg.arg)
def parse_import(self, statement: ast.ImportFrom) -> None:
if statement.module not in self.imports:
self.imports.append(statement.module)

if arg.annotation:
p.schema = annotation_name_to_schema(arg.annotation.id)
def recurse_attribute(self, attr: ast.Attribute) -> str:
if isinstance(attr.value, ast.Attribute):
return self.recurse_attribute(attr.value)
if isinstance(attr.value, ast.Name):
return attr.value.id

default_index = len(statement.args.defaults) - len(statement.args.args) + i
# Only the last len(statement.args.defaults) can have defaults (since they must come after non-default
# parameters)
if default_index >= 0:
p.default = statement.args.defaults[default_index].value
p.required = False
raise TypeError('Unable to determine name of attribute {}'.format(attr.value))

def parse_assigns(self, statement: typing.Union[ast.Assign, ast.AnnAssign]) -> None:
if hasattr(statement, 'targets'):
targets = statement.targets
elif hasattr(statement, 'target'):
targets = [statement.target]
else:
raise TypeError('{} has no target or targets'.format(statement))

for target in targets:
if isinstance(target, ast.Attribute):
self.add_name(self.recurse_attribute(target), self.alters)
continue

if isinstance(target, ast.Name):
if isinstance(statement, ast.AnnAssign):
annotation_name = statement.annotation.id if isinstance(statement.annotation, ast.Name) else None
self.add_variable(target.id, annotation_name)
else:
p.required = True
self.add_name(target.id, self.assigns)

f.parameters.append(p)
if getattr(statement, 'value', None) is not None:
self.parse_statement(statement.value)

declares.append(f)
elif isinstance(statement, (ast.Assign, ast.AnnAssign)):
if hasattr(statement, 'targets'):
targets = statement.targets
elif hasattr(statement, 'target'):
targets = [statement.target]
else:
raise TypeError('statement has no target or targets')

for target in targets:
is_alters = False
if hasattr(target, 'id'):
# simple variable set/declaration
target_name = target.id
elif hasattr(target, 'value'):
target_name = target.value.id
is_alters = True
def parse_bin_op(self, statement: ast.BinOp) -> None:
self.parse_statement(statement.left)
self.parse_statement(statement.right)

def parse_call(self, statement: ast.Call) -> None:
if hasattr(statement, 'args'):
self.parse_statement(statement.args)

if hasattr(statement, 'keywords'):
for kw in statement.keywords:
self.parse_statement(kw.value)

def parse_function_def(self, statement: ast.FunctionDef) -> None:
if statement.name in self.seen_vars:
return

return_ann = statement.returns.id if isinstance(statement.returns, ast.Name) else None

f = Function(statement.name, returns=annotation_name_to_schema(return_ann), parameters=[])

for i, arg in enumerate(statement.args.args):
p = Parameter(arg.arg)

if arg.annotation:
p.schema = annotation_name_to_schema(arg.annotation.id)

default_index = len(statement.args.defaults) - len(statement.args.args) + i
# Only the last len(statement.args.defaults) can have defaults (since they must come after non-default
# parameters)
if default_index >= 0:
default = statement.args.defaults[default_index]

if isinstance(default, ast.Num):
p.default = default.n
elif isinstance(default, ast.Str):
p.default = default.s
elif isinstance(default, ast.NameConstant):
# default of None/True/False
p.default = default.value
else:
raise ValueError("Don't know how to handle this")
self.parse_statement(default)
p.required = False
else:
p.required = True

if target_name not in seen_vars:
if is_alters:
alters.add(target.value.id)
continue
f.parameters.append(p)

v = Variable(target_name)
self.seen_vars.append(f.name)
self.declares.append(f)

if hasattr(statement, 'annotation'):
# assignment with Type Annotation
v.schema = annotation_name_to_schema(statement.annotation.id)
declares.add(v)
else:
assigns.add(v)
seen_vars.add(target_name)
seen_vars.add(target_name)
elif isinstance(statement, ast.Expr) and isinstance(statement.value, ast.Call):
if hasattr(statement.value, 'args'):
for arg in statement.value.args:
if isinstance(arg, ast.Name):
uses.add(arg.id)

if search_for_open:
def parse_dict(self, statement: ast.Dict) -> None:
for key in statement.keys:
if isinstance(key, ast.Name):
self.add_name(key.id, self.uses)
else:
self.parse_statement(key)
for value in statement.values:
if isinstance(value, ast.Name):
self.add_name(value.id, self.uses)
else:
self.parse_statement(value)

def parse_aug_assign(self, statement: ast.AugAssign) -> None:
if isinstance(statement.target, ast.Name):
self.add_name(statement.target.id, self.alters)
else:
self.parse_statement(statement.target)

self.parse_statement(statement.value)

def parse_if(self, statement: ast.If) -> None:
self.parse_statement(statement.test)
self.parse_statement(statement.body)
self.parse_statement(statement.orelse)

def parse_compare(self, statement: ast.Compare) -> None:
self.parse_statement(statement.left)
self.parse_statement(statement.comparators)

def parse_for(self, statement: ast.For) -> None:
if isinstance(statement.target, ast.Name):
self.add_name(statement.target.id, self.assigns)
self.parse_statement(statement.iter)
self.parse_statement(statement.body)

def find_file_reads(self, chunk_ast: ast.Module) -> None:
for node in ast.walk(chunk_ast):
if isinstance(node, ast.Call) and hasattr(node, 'func') and node.func.id == 'open':
filename = parse_open_filename(node)

if filename:
reads.add(filename)

return CodeChunkParseResult(chunk_ast, imports, list(assigns), list(declares), list(alters), list(uses),
list(reads))
if filename and filename not in self.reads:
self.reads.append(filename)


class DocumentCompiler:
Expand Down Expand Up @@ -301,7 +422,8 @@ def handle_item(self, item: typing.Any, compilation_result: DocumentCompilationR
if item.language == self.TARGET_LANGUAGE: # Only add Python code

if isinstance(item, CodeChunk):
cc_result = parse_code_chunk(item)
parser = CodeChunkParser()
cc_result = parser.parse(item)
item.imports = cc_result.imports
item.declares = cc_result.declares
item.assigns = cc_result.assigns
Expand Down

0 comments on commit 0199b1e

Please sign in to comment.