Skip to content

Commit

Permalink
Merge pull request #1107 from pytorch/fb-sync-wwei6
Browse files Browse the repository at this point in the history
[fx2trt] Modify lower setting class to accommodate AIT lowering
  • Loading branch information
Wei authored Jun 9, 2022
2 parents 2e09ce5 + 4df1d24 commit e4e02e1
Showing 1 changed file with 30 additions and 35 deletions.
65 changes: 30 additions & 35 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,65 +10,64 @@


@dc.dataclass
class LowerSetting:
class LowerSettingBasic:
"""
Basic class for lowering.
max_batch_size: The maximum batch size for lowering job.
If run with TensorRT lowering, this is the maximum
batch size which can be used at execution time,
and also the batch size for which the ICudaEngine
will be optimized.
If run with AITemplate lowering, this is the max batch_size
                   for the model.
lower_precision: lower precision dtype during lowering.
min_acc_module_size(int): minimal number of nodes for an accelerate submodule.
ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
    modules that need AST rewriting. This aims to eliminate input variables involved in
    exception-checking control flow.
leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where
modules will not be traced into.
verbose_profile (bool): verbosity of profiler, default to False.
"""
Basic configuration for lowering stack.

Args:
max_batch_size: The maximum batch size which can be used at execution time,
and also the batch size for which the ICudaEngine will be optimized.
max_batch_size: int = 2048
lower_precision: LowerPrecision = LowerPrecision.FP32
min_acc_module_size: int = 10
ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
leaf_module_list: Optional[Set[Type[nn.Module]]] = None
verbose_profile: bool = False


@dc.dataclass
class LowerSetting(LowerSettingBasic):
"""
Basic configuration for lowering stack.
Args:
input_specs: Specs for inputs to engine, can either be a single size or a
range defined by Min, Optimal, Max sizes.
explicit_batch_dimension: Use explicit batch dimension during lowering.
explicit_precision: Use explicit precision during lowering.
lower_precision: lower precision dtype during lowering.
max_workspace_size: The maximum workspace size. The maximum GPU temporary
memory which the TensorRT engine can use at execution time.
strict_type_constraints: Require TensorRT engine to strictly follow data type
setting at execution time.
customized_fuse_pass: List of customized passes to apply during the lowering process.
lower_basic_fuse_pass: Enable basic fuse passes during lowering, i.e. fuse multiple operations
as (a->b->c->d)=>(e). Current basic fuse patterns are:
permute->linear
permute->matmul
verbose_log: Enable TensorRT engine verbose log mode.
algo_selector: Enable TensorRT algorithm selector at execution time.
timing_cache_prefix: TensorRT timing cache file path. TensorRT engine will use timing
cache file at execution time if valid timing cache file is provided.
save_timing_cache: Save updated timing cache data into timing cache file if the timing
cache file is provided.
ast_rewriter_allow_list (Optional[Set[nn.Module]]): Optional allow list of
    modules that need AST rewriting. This aims to eliminate input variables involved in
    exception-checking control flow.
leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where
modules will not be traced into.
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
verbose_profile (bool): verbosity of profiler, default to False.
min_acc_module_size(int): minimal number of nodes for an accelerate submodule.
"""

max_batch_size: int = 2048
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
explicit_batch_dimension: bool = True
explicit_precision: bool = False
lower_precision: LowerPrecision = LowerPrecision.FP32
max_workspace_size: int = 1 << 30
strict_type_constraints: bool = False
customized_fuse_pass: PassManager = PassManager.build_from_passlist([])
Expand All @@ -79,8 +78,4 @@ class LowerSetting:
algo_selector = None
timing_cache_prefix: str = ""
save_timing_cache: bool = False
ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
leaf_module_list: Optional[Set[Type[nn.Module]]] = None
cuda_graph_batch_size: int = -1
verbose_profile: bool = False
min_acc_module_size: int = 10

0 comments on commit e4e02e1

Please sign in to comment.