-
Notifications
You must be signed in to change notification settings - Fork 1.8k
fix(speedup): re-write aten schema parser to support pytorch versions < 1.9.0 #5138
Conversation
I suggest to keep both version and use a flag to switch. |
Azure Pipelines successfully started running 1 pipeline(s). |
/azp run full test - compression |
Azure Pipelines successfully started running 1 pipeline(s). |
@@ -486,7 +748,10 @@ def generate_aten_to_python(func: Callable, node: NodePyGroup, speedup: ModelSpe | |||
c_node = node.key_node | |||
|
|||
schema = c_node.schema() | |||
positional_num, keyword_list, special_treat = parse_aten_schema(schema) | |||
if torch.__version__ < '1.9.0': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the type of torch.__version__
is not a str
. it's a version type that can correctly compare with str.
Parse the schema, to positional_num and keyword_list, and detect if the argument should be specially treated. | ||
Cannot use 'torch._C.parse_schema' because 'torch._C.Argument' has no 'kwarg_only' on pytorch v1.8.x | ||
Using a lexer-parser like method to parse it. | ||
Re-write from torch/csrc/jit/frontend/function_schema_parser.cpp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if this function a mirror rewrite from this cpp
file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not totally equal to the original function in cpp. Actually, the original parser in that cpp file is too specific and too long (5k+ lines). I wrote a much simpler syntax analyzer, deleted some error check and specific type check to shorten the code.
/azp run full test - compression |
Azure Pipelines successfully started running 1 pipeline(s). |
/azp run full test - nas |
Azure Pipelines successfully started running 1 pipeline(s). |
/azp run fast test |
Azure Pipelines successfully started running 1 pipeline(s). |
Description
reason:
the official aten schema parser of python lost the 'kwarg_only' info in python versions <1.9.0.
and schema is not simple enough to use some regex to parser.
solution:
so I translated the schema parser code from pytorch 1.12 c++ code to python and adapted some deleted syntaxes in pytorch 1.8.
tested in pytorch versions 1.8, 1.9, 1.10, 1.11 and 1.12.
note:
torch._C.Node
supportsschema
function since version 1.8. so speedup cannot use on pytorch 1.7 now.#5131 can be solved.
Test Options
Checklist
How to test