Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix(speedup): re-write aten schema parser to support pytorch versions < 1.9.0 #5138

Merged
merged 5 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 267 additions & 2 deletions nni/compression/pytorch/speedup/jit_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from nni.common.graph_utils import NodePyGroup

import re
import string
import logging
from functools import partial, lru_cache
import copy
Expand Down Expand Up @@ -394,10 +395,11 @@ def arg_trans_layout(ivalue: Union[int, torch.layout]):
# ce=None, bool? pin_memory=None) -> (Tensor"""'
}

@lru_cache(maxsize=256)
@lru_cache
def parse_aten_schema(schema: str):
"""
Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
only available on pytorch >= v1.9.0
"""
if schema in schema_fix_dict:
schema = schema_fix_dict[schema]
Expand All @@ -422,6 +424,266 @@ def parse_aten_schema(schema: str):

return positional_num, keyword_list, special_treat

@lru_cache
def parse_aten_schema_version_1_8_x(schema: str):
    """
    Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated.
    Cannot use 'torch._C.parse_schema' because 'torch._C.Argument' has no 'kwarg_only' on pytorch v1.8.x
    Using a lexer-parser like method to parse it.
    Re-write from torch/csrc/jit/frontend/function_schema_parser.cpp

    This is not a mirror of the parser in that cpp file: it is a much simpler syntax analyzer, with some
    error checks and specific type checks removed to keep the code short.
    Only the declaration up to the ')' that closes the argument list is parsed; whatever follows it
    (the returns) is tokenized but never consumed.

    Parameters
    ----------
    schema
        The schema string. If it is a key of 'schema_fix_dict', the replacement stored there is parsed instead.

    Returns
    -------
    positional_num
        The number of arguments declared before a bare '*'.
    keyword_list
        The names of the arguments declared after a bare '*', in declaration order.
    special_treat
        Maps a positional index (int) or a keyword name (str) to the list of entries found
        in 'special_treat_dict' for the argument's name.

    Raises
    ------
    Exception
        If the schema cannot be tokenized or parsed.
    """
    if schema in schema_fix_dict:
        schema = schema_fix_dict[schema]

    single_solid_tokens = [
        '(', ')', '[', ']',
        '+', '-', '!', '>',
        '|', '=', ':', '.', ',',
        '?', '*',
    ]
    # no '>=', '<=', '&', '/'
    # '|' only occurs in 'Tensor(a|b)'
    spec_tokens = [
        'numdigits', 'string', 'quoted', 'unknown',
    ]
    str_chars_first = (*string.ascii_letters, '_')
    str_chars = (*string.ascii_letters, *string.digits, '_')
    num_chars_first = (*string.digits,)
    num_chars_16 = (*string.digits, *string.ascii_lowercase[:6], *string.ascii_uppercase[:6])

    # A token is either a one-character str taken from 'single_solid_tokens',
    # or a mutable [kind, text] list whose kind is one of 'spec_tokens'.
    tokens = list()
    # 1: in ('\'', '"'); 2: in num; 3: in str;
    status = 0
    status_esc_char = False

    for char in schema:
        if status == 1:
            if status_esc_char:
                # the escaped character is kept, the backslash itself is dropped
                status_esc_char = False
                tokens[-1][1] += char
            elif char == '\\':
                status_esc_char = True
            else:
                tokens[-1][1] += char
                # the literal ends at the same quote character that opened it
                if char == tokens[-1][1][0]:
                    status = 0
            continue
        elif status == 2:
            if char in num_chars_16:
                tokens[-1][1] += char
                continue
            else:
                status = 0
        elif status == 3:
            if char in str_chars:
                tokens[-1][1] += char
                continue
            else:
                status = 0
        # reached with status == 0 only: either we were between tokens,
        # or 'char' has just terminated a number / identifier
        if status == 0:
            if char in single_solid_tokens:
                tokens.append(char)
            elif char in ('\'', '\"'):
                tokens.append(['quoted', char])
                status = 1
            elif char in num_chars_first:
                tokens.append(['numdigits', char])
                status = 2
            elif char in str_chars_first:
                tokens.append(['string', char])
                status = 3
            elif char not in ('\n', ' ', '\t'):
                tokens.append(['unknown', char])
    # A number or an identifier that runs up to the end of the input is a complete token
    # (e.g. a schema ending in '-> Tensor'); only an unterminated quoted literal is an error.
    if status == 1:
        raise Exception('aten schema parse error')

    index = 0

    def peek(index_diff=0):
        # token at the cursor (plus an offset), or None when that is past the end of the input
        pos = index + index_diff
        return tokens[pos] if pos < len(tokens) else None

    def is_kind(token, tk: str) -> bool:
        # 'spec_tokens' are matched on their kind, every other token on its text
        if tk in spec_tokens:
            return isinstance(token, list) and token[0] == tk
        return token == tk

    def next_pass(index_diff=1):
        # advance the cursor; the token that was skipped is returned only for a single step
        nonlocal index
        index += index_diff
        if index_diff == 1:
            return tokens[index - 1]

    def next_if(tk: str, index_diff=0) -> bool:
        # look ahead without consuming anything
        return is_kind(peek(index_diff), tk)

    def next_if_pass_value(tk: str, default_value=None) -> Optional[str]:
        # consume the next token and return its text if it matches, else return 'default_value'
        nonlocal index
        token = peek()
        if not is_kind(token, tk):
            return default_value
        index += 1
        return token[1] if tk in spec_tokens else tk

    def next_expect_pass_value(tk: str) -> str:
        # consume the next token and return its text; anything else is a parse error
        nonlocal index
        token = peek()
        if not is_kind(token, tk):
            raise Exception('aten schema parse error')
        index += 1
        return token[1] if tk in spec_tokens else tk

    def parse_number():
        # returns the text of the number, or None (consuming nothing) if no number starts here
        if next_if('+') or next_if('-'):
            value = next_pass() + next_expect_pass_value('numdigits')
        elif (get := next_if_pass_value('numdigits')) is not None:
            value = get
        else:
            return None
        if next_if_pass_value('.') is not None:
            value += '.'
            if (get := next_if_pass_value('numdigits')):
                value += get
        # the lexer accepts hex digits inside a number, so '1e-10' arrives as '1e', '-', '10'
        if value[-1] == 'e' and next_if_pass_value('-') is not None:
            # only occur in versions < 1.9.0
            # 1e-10
            value += '-' + next_expect_pass_value('numdigits')
        return value

    def parse_name():
        # 'name', 'namespace::name', each optionally followed by '.overload_name'
        name = next_expect_pass_value('string')
        if next_if_pass_value(':') is not None:
            next_expect_pass_value(':')
            name += '::' + next_expect_pass_value('string')
        overload_name = ''
        if next_if_pass_value('.') is not None:
            overload_name = next_expect_pass_value('string')
        return name, overload_name

    def parse_list(sep, end, callback):
        # 'callback' results interleaved with the separators, plus 'end' when one is required
        ret = []
        if end is None or not next_if(end):
            ret.append(callback())
            while (get := next_if_pass_value(sep)) is not None:
                ret.append(get)
                ret.append(callback())
        if end is not None:
            ret.append(next_expect_pass_value(end))
        return ret

    def parse_alias_annotation():
        if next_if_pass_value('(') is not None:
            def parse_inner():
                if next_if_pass_value('*') is not None:
                    return '*'
                else:
                    return next_expect_pass_value('string')

            # put the consumed '(' back in front of the alias set
            value = '(' + ''.join(parse_list('|', None, parse_inner))
            value += next_if_pass_value('!', '')
            if next_if('-') and next_if('>', 1):
                next_pass(2)
                value += '->'
                value += ''.join(parse_list('|', None, parse_inner))
            return value + next_expect_pass_value(')')
        else:
            return next_if_pass_value('!', '')

    def parse_type():
        # The returned text is informational only: the caller of 'parse_argument'
        # keeps just the argument name.
        if next_if_pass_value('(') is not None:
            # tuple type; 'parse_list' already appends the closing ')'
            value = '(' + ''.join(parse_list(',', ')', parse_type))
        else:
            value = next_expect_pass_value('string')
            if value == '__torch__':
                # only occur in versions < 1.9.0
                while (get := next_if_pass_value('.')) is not None:
                    value += get + next_expect_pass_value('string')
            if next_if_pass_value('('):
                # 'parse_list' already appends the closing ')'
                the_types = ''.join(parse_list(',', ')', parse_type))
                value += '(' + the_types
        value += parse_alias_annotation()
        while True:
            if next_if('[') and next_if(']', 1):
                next_pass(2)
                value += '[]'
                value += parse_alias_annotation()
            elif next_if_pass_value('?') is not None:
                value += '?'
            elif next_if_pass_value('-') is not None:
                # only occur in versions < 1.9.0
                # t(x -> *)
                value += '-' + next_expect_pass_value('>') + next_expect_pass_value('*')
                break
            else:
                break
        return value

    def parse_default_value():
        if next_if_pass_value('[') is not None:
            return parse_list(',', ']', parse_default_value)
        elif (get := parse_number()) is not None:
            return get
        elif (get := next_if_pass_value('quoted')) is not None:
            return get
        else:
            return next_expect_pass_value('string')

    def parse_argument():
        the_type = parse_type()
        if next_if_pass_value('[') is not None:
            # a '[' left over by 'parse_type' was not followed by ']', so a number must come next
            size = parse_number()
            if size is None:
                raise Exception('aten schema parse error')
            the_type += '[' + size + next_expect_pass_value(']')
        the_type += parse_alias_annotation()
        the_type += next_if_pass_value('?', '')
        name = next_expect_pass_value('string')
        default_value = ''
        if next_if_pass_value('=') is not None:
            default_value = parse_default_value()
        return the_type, name, default_value

    def parse_declaration():
        name, overload_name = parse_name()
        arguments = list()
        kwarg_only = False
        is_vararg = False
        next_expect_pass_value('(')

        def parse_inner():
            nonlocal kwarg_only
            nonlocal is_vararg
            if is_vararg:
                raise Exception('"..." must be the last element')
            elif next_if_pass_value('*') is not None:
                # every argument after a bare '*' is keyword-only
                kwarg_only = True
            elif next_if_pass_value('.') is not None:
                next_expect_pass_value('.')
                next_expect_pass_value('.')
                is_vararg = True
            else:
                arguments.append((parse_argument()[1], kwarg_only))

        parse_list(',', ')', parse_inner)
        return name, overload_name, arguments, is_vararg

    positional_num = 0
    keyword_list = list()
    special_treat = dict() # for dtype and memory_format trans now

    for name, kwarg_only in parse_declaration()[2]:
        if not kwarg_only:
            key = positional_num
            positional_num += 1
        else:
            key = name
            keyword_list.append(key)

        if name in special_treat_dict:
            if key not in special_treat:
                special_treat[key] = [special_treat_dict[name]]
            else:
                special_treat[key].append(special_treat_dict[name])

    return positional_num, keyword_list, special_treat

def parse_input_value(speedup: ModelSpeedup, input_nodes: List[torch._C.Node], positional_num: int, keyword_list: List[str]):
"""
translate inputs, to constant positional arguments, constant keyword arguments, and undetermined positions
Expand Down Expand Up @@ -486,7 +748,10 @@ def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpe
c_node = node.key_node

schema = c_node.schema()
positional_num, keyword_list, special_treat = parse_aten_schema(schema)
if torch.__version__ < '1.9.0':
Copy link
Contributor

@J-shang J-shang Oct 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

I'm worried about this method of checking the version.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of torch.__version__ is not a str. It's a version type that can be compared correctly with a str.

positional_num, keyword_list, special_treat = parse_aten_schema_version_1_8_x(schema)
else:
positional_num, keyword_list, special_treat = parse_aten_schema(schema)

input_nodes = list(c_node.inputs())
positional, keyword, undetermined = parse_input_value(speedup, input_nodes, positional_num, keyword_list)
Expand Down
50 changes: 50 additions & 0 deletions test/algo/compression/v2/test_schema_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch

from nni.compression.pytorch.speedup.jit_translate import parse_aten_schema_version_1_8_x, schema_fix_dict, special_treat_dict

def parse_aten_schema_origin(schema: str):
    """
    Reference parser built on 'torch._C.parse_schema'.

    Returns the tuple (positional_num, keyword_list, special_treat) for the schema.
    On pytorch < 1.9.0 'kwarg_only' is never read and every argument is counted as positional.
    """
    schema = schema_fix_dict.get(schema, schema)

    positional_num = 0
    keyword_list = []
    special_treat = {}  # for dtype and memory_format trans now

    for argument in torch._C.parse_schema(schema).arguments:
        arg_name = argument.name
        # the version check comes first so that 'kwarg_only' is not touched on old versions
        if not (torch.__version__ < '1.9.0') and argument.kwarg_only:
            key = arg_name
            keyword_list.append(key)
        else:
            key = positional_num
            positional_num += 1

        if arg_name in special_treat_dict:
            special_treat.setdefault(key, []).append(special_treat_dict[arg_name])

    return positional_num, keyword_list, special_treat

class SchemaParserTestCase(unittest.TestCase):
    """Compare the hand-written schema parser with the 'torch._C.parse_schema' based one."""

    def test_diff_manual_parser(self):
        """Both parsers must give the same result for every 'aten::' schema torch reports."""
        for schema in map(str, torch._C._jit_get_all_schemas()):
            if not schema.startswith('aten::'):
                continue
            if torch.__version__ < '1.9.0' and '*,' in schema:
                # on these versions the reference parser counts every argument as positional,
                # so it cannot serve as the expected result for schemas with keyword-only arguments
                continue
            # one sub-test per schema: a failure names the schema and the loop keeps going
            with self.subTest(schema=schema):
                positional_num_origin, keyword_list_origin, special_treat_origin = parse_aten_schema_origin(schema)
                positional_num_manual, keyword_list_manual, special_treat_manual = parse_aten_schema_version_1_8_x(schema)

                self.assertEqual(positional_num_origin, positional_num_manual)
                self.assertEqual(keyword_list_origin, keyword_list_manual)
                self.assertEqual(special_treat_origin, special_treat_manual)

# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()