-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Collage] CollagePartition pass (#12086)
* [Collage] CollagePartition pass See https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md. This adds the main CollagePartition pass, which: 1. Inspects all the targets in the CompilationConfig and builds PartitionSpecs describing how to generate speculative CandidatePartitions for them. 2. Runs the above rules on the model to collect all the candidates. 3. Eliminates candidates whose target contradicts any constraints already imposed by, eg, device planning. 4. Eagerly estimates the cost of each candidate. 5. Performs a shortest path search to choose an 'optimal' set of candidate partitions so as to minimize estimated model latency, such that every sub-expression node is contained in exactly one candidate partition. 6. Coalesces adjacent optimal candidates which ended up on the same target. 7. Rewrites the model according to the chosen optimal partitioning. As for the existing partition_for_<external codegen name> methods, the result of CollagePartition can then be built using regular TVM. Very special thanks to @mbaret for authoring test_pass_collage_partition.py. Logic to prune the candidates after step 3 will be in a follow up PR since it deserves its own testing. A demonstration driver will also come as a follow up. * - lints * - more lints * - use the _ffi_api properly
- Loading branch information
1 parent
e084791
commit 7661ba8
Showing
18 changed files
with
1,854 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""relay.collage exports""" | ||
from .collage import ( | ||
MEASURE_NUMBER, | ||
MEASURE_REPEAT, | ||
WARMUP_MIN_REPEAT_MS, | ||
CostEstimator, | ||
MockEstimator, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI APIs for the Collage partitioner.""" | ||
import tvm._ffi | ||
|
||
|
||
# Expose every C++ global function registered under the "relay.collage" prefix
# as an attribute of this module.
tvm._ffi._init_api("relay.collage", __name__)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Mostly helper methods which interface the main C++ Collage implementation with Python. | ||
See relay.transform.CollagePartition for the main Collage entrypoint.""" | ||
|
||
import logging | ||
import os | ||
import math | ||
import tempfile | ||
|
||
import numpy as np | ||
|
||
import tvm | ||
from tvm._ffi.registry import register_func, register_object | ||
from tvm.runtime import Object | ||
from . import _ffi_api | ||
|
||
# Parameters to use when estimating latency (of both partitions and overall models). | ||
MEASURE_NUMBER = 20 | ||
MEASURE_REPEAT = 5 | ||
WARMUP_MIN_REPEAT_MS = 250 | ||
|
||
|
||
@register_object("relay.collage.CostEstimator")
class CostEstimator(Object):
    """Estimates the cost (latency) of candidate partitions.

    Thin Python handle over the C++ ``relay.collage.CostEstimator`` object;
    all estimation logic lives on the C++ side.
    """

    def __init__(self):
        self.__init_handle_by_constructor__(_ffi_api.CostEstimator)
|
||
|
||
@register_object("relay.collage.MockEstimator")
class MockEstimator(Object):
    """Mock cost estimator returning fixed, pre-supplied costs (for testing).

    Thin Python handle over the C++ ``relay.collage.MockEstimator`` object,
    constructed from target_costs -- presumably a map from target to assumed
    cost; see the C++ implementation for the exact schema (TODO confirm).
    """

    def __init__(self, target_costs):
        self.__init_handle_by_constructor__(_ffi_api.MockEstimator, target_costs)
|
||
|
||
def arg_for(arg_type, device):
    """Returns a random test argument of Relay type arg_type resident on device.

    Values are drawn uniformly from [-1, 1) and cast to the tensor's dtype.
    """
    assert isinstance(arg_type, tvm.ir.TensorType)
    raw = np.random.uniform(-1.0, 1.0, size=arg_type.concrete_shape)
    return tvm.nd.array(raw.astype(arg_type.dtype), device=device)
|
||
|
||
def vm_estimate_seconds(device, the_vm, func_name, args):
    """Returns the estimated latency, in seconds, of running func_name with args on the_vm."""

    def run(repeat, number, min_repeat_ms):
        # All benchmark runs share the same device, function and arguments.
        return the_vm.benchmark(
            device,
            repeat=repeat,
            number=number,
            min_repeat_ms=min_repeat_ms,
            func_name=func_name,
            **args,
        )

    # Throw-away warmup run.
    run(1, 1, WARMUP_MIN_REPEAT_MS)
    # One more time, with feeling: the run whose result we actually report.
    return run(MEASURE_REPEAT, MEASURE_NUMBER, 0)
|
||
|
||
@register_func("tvm.relay.collage.estimate_seconds")
def estimate_seconds(mod, target):
    """Returns the estimated execution time, in seconds, of "main" in mod on target.

    The result is the median over repeated benchmark runs (see vm_estimate_seconds).
    The module may contain "Primitive" functions, possibly with "Compiler" attributes.
    Returns math.inf if the module fails to build, which is taken to mean the
    candidate partition is not supported on target.
    """
    device = tvm.device(target.kind.device_type)

    try:
        # Build the module.
        logging.info("Compiling module to estimate")
        exe = tvm.relay.vm.compile(mod, target)
    except RuntimeError as err:
        # A build failure indicates the partition is not supported.
        # eg trying to build an nn.batch_norm on GPU, which has no schedule since we assume it
        # is only ever used with a tuple projection which is rewritten away.
        logging.info("Assigning module infinite cost since unable to build: %s", err)
        return math.inf

    # Finalize compilation: export to a shared library and re-load it.
    # NOTE(review): the temporary directory is never removed -- leaks one dir per estimate.
    tmp_dir = tempfile.mkdtemp()
    code, lib = exe.save()
    lib_path = os.path.join(tmp_dir, "library.so")
    # TODO(mbs): Avoid nvcc dependency?
    lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
    lib = tvm.runtime.load_module(lib_path)
    exe = tvm.runtime.vm.Executable.load_exec(code, lib)

    # Benchmark the module with random arguments (see arg_for).
    the_vm = tvm.runtime.vm.VirtualMachine(exe, device)
    func_name = "main"
    main_args = {v.name_hint: arg_for(v.checked_type, device) for v in mod[func_name].params}
    logging.info("Benchmarking module to estimate")
    profile = vm_estimate_seconds(device, the_vm, func_name, main_args)
    logging.info("profile: %s", profile)
    return profile.median  # seconds
|
||
|
||
def make_labelled_dfpattern_partition_rule_wrapper(compiler, pattern_tuple):
    """Returns a DFPatternPartitionRule representing one (label, pattern, predicate) entry from
    the pattern table for external codegen compiler"""
    rule_name, dataflow_pattern = pattern_tuple[0], pattern_tuple[1]
    if len(pattern_tuple) == 2:
        # Entry without a predicate.
        return _ffi_api.MakeLabelledDFPatternPartitionRule(compiler, rule_name, dataflow_pattern)
    # Entry with a predicate as third element.
    return _ffi_api.MakeLabelledDFPatternPartitionRuleWithPredicate(
        compiler, rule_name, dataflow_pattern, pattern_tuple[2]
    )
|
||
|
||
@register_func("tvm.relay.collage.make_byoc_partition_rule")
def make_byoc_partition_rule(compiler):
    """Returns the PartitionRule for external codegen compiler"""
    pattern_table = tvm.relay.op.contrib.get_pattern_table(compiler)
    assert (
        pattern_table is not None
    ), f"No pattern table entry was found for BYOC compiler {compiler}"
    logging.info(
        "Converting %s rules for %s for use in pattern style BYOC lowering/codegen",
        len(pattern_table),
        compiler,
    )
    # Wrap every pattern-table entry as a labelled DFPattern sub-rule.
    sub_rules = []
    for entry in pattern_table:
        sub_rules.append(make_labelled_dfpattern_partition_rule_wrapper(compiler, entry))
    return _ffi_api.MakePatternBYOCPartitionRule(compiler, sub_rules)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file relay/collage/candidate_partition_index.cc
* \brief Index for finding relevant candidate partitions for a particular search state. | ||
*/ | ||
|
||
#include "./candidate_partition_index.h" | ||
|
||
#include "./gather_partition_specs.h" | ||
#include "./utils.h" | ||
|
||
namespace tvm { | ||
namespace relay { | ||
namespace collage { | ||
|
||
/*!
 * \brief Constructs an (as yet unpopulated) index over \p dataflow_graph.
 *
 * \param virtual_devices Map from sub-expression nodes to their virtual device, used to
 *        reject candidates whose target contradicts an existing constraint (presumably
 *        established by device planning -- see IsCompatibleWithVirtualDevice).
 * \param dataflow_graph The dataflow graph all candidate partitions are relative to. The
 *        per-index candidate buckets are pre-sized to the graph's node count.
 */
CandidatePartitionIndex::CandidatePartitionIndex(
    const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
    DataflowGraph* dataflow_graph)
    : virtual_devices_(virtual_devices),
      dataflow_graph_(dataflow_graph),
      first_inside_index_to_candidates_(dataflow_graph->size()) {}
|
||
void CandidatePartitionIndex::Index(const Array<PartitionSpec>& partition_specs) { | ||
std::vector<CandidatePartition> candidates = Collect(partition_specs); | ||
|
||
// (The candidates could be pruned at this point to elliminate those which are heuristically | ||
// unlikely to appear in the optimal partitioning.) | ||
|
||
// Index the candidates by their first inside index. | ||
for (auto& candidate : candidates) { | ||
first_inside_index_to_candidates_[candidate->sub_graph_->first_inside_index_].emplace_back( | ||
candidate); | ||
} | ||
size_ = candidates.size(); | ||
} | ||
|
||
void CandidatePartitionIndex::EstimateAllCosts( | ||
const CostEstimator cost_estimator, const std::shared_ptr<CandidateFunctionCache>& cache) { | ||
size_t n = 0; | ||
for (PostDfsIndex index = 0; index < dataflow_graph_->size(); ++index) { | ||
for (const auto& candidate : first_inside_index_to_candidates_[index]) { | ||
LOG(INFO) << "Estimating cost of candidate " << candidate->ToSummary(*dataflow_graph_) << " [" | ||
<< n++ << "/" << size_ << "]"; | ||
// Cost will be cached in candidate as a side effect. | ||
Cost cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator, cache); | ||
LOG(INFO) << "Candidate has cost " << cost.ToString(); | ||
} | ||
} | ||
} | ||
|
||
std::string CandidatePartitionIndex::ToSummary() const { | ||
std::vector<std::string> lines; | ||
for (const auto& candidates : first_inside_index_to_candidates_) { | ||
for (const auto& candidate : candidates) { | ||
if (candidate->partition_spec_name() == kHostSpecName) { | ||
continue; | ||
} | ||
lines.emplace_back(candidate->ToSummary(*dataflow_graph_)); | ||
} | ||
} | ||
std::sort(lines.begin(), lines.end()); | ||
std::ostringstream os; | ||
bool first = true; | ||
for (const auto& line : lines) { | ||
if (first) { | ||
first = false; | ||
} else { | ||
os << std::endl; | ||
} | ||
os << line; | ||
} | ||
return os.str(); | ||
} | ||
|
||
/*!
 * \brief Returns true if \p candidate's target does not contradict the virtual device
 * already assigned (eg by device planning) to any sub-expression node it covers. A node
 * is compatible if it is target-polymorphic, has no existing target constraint, matches
 * the candidate's target exactly, or the candidate's target is external codegen for the
 * existing target.
 */
bool CandidatePartitionIndex::IsCompatibleWithVirtualDevice(const CandidatePartition& candidate) {
  for (PostDfsIndex index : candidate->sub_graph_->inside_) {
    const ExprNode* sub_expr_node = dataflow_graph_->index_to_node(index)->node_ref_;
    if (sub_expr_node->IsInstance<OpNode>() || sub_expr_node->IsInstance<ConstructorNode>()) {
      // These nodes are target/device polymorphic.
      continue;
    }
    auto itr = virtual_devices_->find(sub_expr_node);
    ICHECK(itr != virtual_devices_->end()) << PrettyPrint(GetRef<Expr>(sub_expr_node));
    const Target& existing_target = itr->second->target;
    if (!existing_target.defined()) {
      // No constraint.
      continue;
    }
    if (StructuralEqual()(existing_target, candidate->target())) {
      // No disagreement.
      continue;
    }
    if (!candidate->target().IsExternalCodegenFor(itr->second->target)) {
      // The candidate's target is not an external codegen target compatible with the existing
      // target.
      // TODO(mbs): There's a conflict here between Collage's desire to leave some expression nodes
      // 'behind' on the VM and PlanDevice's desire to assign a primitive Target to every node.
      // I think PlanDevices is the one that needs to give here by leaving such nodes
      // unconstrained.
      VLOG(1) << "Ignoring candidate " << candidate->ToString()
              << " since incompatible with existing virtual device assignment of:" << std::endl
              << itr->second << std::endl
              << "to sub-graph:" << std::endl
              << PrettyPrint(GetRef<Expr>(sub_expr_node));
      return false;
    }
  }
  return true;
}
|
||
std::vector<CandidatePartition> CandidatePartitionIndex::Collect( | ||
const Array<PartitionSpec>& partition_specs) { | ||
VLOG_CONTEXT << "collecting"; | ||
std::vector<CandidatePartition> result; | ||
for (const auto& spec : partition_specs) { | ||
VLOG_CONTEXT << "spec " << spec->spec_name_; | ||
VLOG(1) << "collecting candidates"; | ||
std::vector<CandidatePartition> candidates = spec->AllCandidates(*dataflow_graph_); | ||
for (auto& candidate : candidates) { | ||
if (!IsCompatibleWithVirtualDevice(candidate)) { | ||
continue; | ||
} | ||
result.push_back(candidate); | ||
} | ||
} | ||
VLOG(1) << "Found " << result.size() << " candidates"; | ||
return result; | ||
} | ||
|
||
} // namespace collage | ||
} // namespace relay | ||
} // namespace tvm |
Oops, something went wrong.