Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial support for -fwasm-exceptions in Asyncify (#5343) #5475

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 276 additions & 8 deletions src/passes/Asyncify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@
// Overall, this should allow good performance with small overhead that is
// mostly noticed at rewind time.
//
// Exceptions handling (-fwasm-exceptions) is partially supported. Asyncify
// can't start unwind operation when a catch block is in the stack trace.
// If assertions mode is enabled then pass will check if unwind called from
// within catch block or not, and if so throw an unreachable exception.
// If "ignore unwind from catch" mode is enable then Asyncify will skip
// any unwind call from within catch block.
//
// After this pass is run a new i32 global "__asyncify_state" is added, which
// has the following values:
//
Expand Down Expand Up @@ -239,6 +246,12 @@
// an unwind/rewind in an invalid place (this can be helpful for manual
// tweaking of the only-list / remove-list, see later).
//
// --pass-arg=asyncify-ignore-unwind-from-catch
//
// This enables additional check to be performed before unwinding. In
// cases where the unwind operation is triggered from the catch block,
// it will be silently ignored (-fwasm-exceptions support)
//
// --pass-arg=asyncify-verbose
//
// Logs out instrumentation decisions to the console. This can help figure
Expand Down Expand Up @@ -313,13 +326,16 @@

#include "asmjs/shared-constants.h"
#include "cfg/liveness-traversal.h"
#include "ir/branch-utils.h"
#include "ir/effects.h"
#include "ir/eh-utils.h"
#include "ir/find_all.h"
#include "ir/linear-execution.h"
#include "ir/literal-utils.h"
#include "ir/memory-utils.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/parents.h"
#include "ir/utils.h"
#include "pass.h"
#include "passes/pass-utils.h"
Expand All @@ -334,6 +350,8 @@ namespace {

static const Name ASYNCIFY_STATE = "__asyncify_state";
static const Name ASYNCIFY_GET_STATE = "asyncify_get_state";
static const Name ASYNCIFY_CATCH_COUNTER = "__asyncify_catch_counter";
static const Name ASYNCIFY_GET_CATCH_COUNTER = "asyncify_get_catch_counter";
static const Name ASYNCIFY_DATA = "__asyncify_data";
static const Name ASYNCIFY_START_UNWIND = "asyncify_start_unwind";
static const Name ASYNCIFY_STOP_UNWIND = "asyncify_stop_unwind";
Expand Down Expand Up @@ -1134,6 +1152,18 @@ struct AsyncifyFlow : public Pass {
// here as well.
results.push_back(makeCallSupport(curr));
continue;
} else if (auto* try_ = curr->dynCast<Try>()) {
if (item.phase == Work::Scan) {
work.push_back(Work{curr, Work::Finish});
work.push_back(Work{try_->body, Work::Scan});
// catchBodies are ignored because we assume that pause/resume will
// not happen inside them
continue;
}
try_->body = results.back();
results.pop_back();
results.push_back(try_);
continue;
}
// We must handle all control flow above, and all things that can change
// the state, so there should be nothing that can reach here - add it
Expand Down Expand Up @@ -1214,6 +1244,202 @@ struct AsyncifyFlow : public Pass {
}
};

// Add catch block counters to verify that unwind is not called from catch
// block.
struct AsyncifyAddCatchCounters : public Pass {
bool isFunctionParallel() override { return true; }

std::unique_ptr<Pass> create() override {
return std::make_unique<AsyncifyAddCatchCounters>();
}

void runOnFunction(Module* module_, Function* func) override {
class CountersBuilder : public Builder {
public:
CountersBuilder(Module& wasm) : Builder(wasm) {}
Expression* makeInc(int amount = 1) {
return makeGlobalSet(
ASYNCIFY_CATCH_COUNTER,
makeBinary(AddInt32,
makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32),
makeConst(int32_t(amount))));
}
Expression* makeDec(int amount = 1) {
return makeGlobalSet(
ASYNCIFY_CATCH_COUNTER,
makeBinary(SubInt32,
makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32),
makeConst(int32_t(amount))));
}
};

// with this walker we will handle those changes of counter:
// - entering top-level catch (= pop) +1
// - entering nested catch (= pop) 0 (ignored)
//
// - return inside top-level/nested catch -1
// - return outside top-level/nested catch 0 (ignored)
//
// - break target outside of top-level catch -1
// - break target inside of top-level catch 0 (ignored)
// - break outside top-level/nested catch 0 (ignored)
//
// - exiting from top-level catch -1
// - exiting from nested catch 0 (ignored)
struct AddCountersWalker : public PostWalker<AddCountersWalker> {
Function* func;
CountersBuilder* builder;
BranchUtils::BranchTargets* branchTargets;
Parents* parents;
int finallyNum = 0;
int popNum = 0;

int getCatchCount(Expression* expression) {
int catchCount = 0;
while (expression != func->body) {
auto parent = parents->getParent(expression);
if (auto* try_ = parent->dynCast<Try>()) {
if (try_->body != expression) {
catchCount++;
}
}
expression = parent;
}

return catchCount;
}

// Each catch block except catch_all should have pop instruction
// We increment counter each time when we enter top-level catch block
void visitPop(Pop* pop) {
if (getCatchCount(pop) == 1) {
auto name =
func->name.toString() + "-pop-" + std::to_string(++popNum);
replaceCurrent(
builder->makeBlock(name, {pop, builder->makeInc()}, Type::none));
}
}
void visitLocalSet(LocalSet* set) {
auto block = set->value->dynCast<Block>(); // from visitPop above
if (block && block->name.hasSubstring("-pop-")) {
auto pop = block->list[0]->dynCast<Pop>();
assert(pop && getCatchCount(pop) == 1);
set->value = pop;
replaceCurrent(builder->makeBlock(
block->name, {set, builder->makeInc()}, Type::none));
}
}

// When return happens we decrement counter on 1, because we account
// only top-level catch blocks
// catch
// +1
// catch
// ;; not counted
// -1
// return
// ...
void visitReturn(Return* ret) {
if (getCatchCount(ret) > 0) {
replaceCurrent(builder->makeSequence(builder->makeDec(), ret));
}
}

// When break happens we decrement counter only if it goes out
// from top-level catch block
void visitBreak(Break* br) {
Expression* target = branchTargets->getTarget(br->name);
assert(target != nullptr);
if (getCatchCount(br) > 0 && getCatchCount(target) == 0) {
if (br->condition == nullptr) {
replaceCurrent(builder->makeSequence(builder->makeDec(), br));
} else if (br->value == nullptr) {
auto decIf =
builder->makeIf(br->condition,
builder->makeSequence(builder->makeDec(), br),
nullptr);
br->condition = nullptr;
replaceCurrent(decIf);
} else {
Index newLocal = builder->addVar(func, br->value->type);
auto setLocal = builder->makeLocalSet(newLocal, br->value);
auto getLocal = builder->makeLocalGet(newLocal, br->value->type);
auto condition = br->condition;
br->condition = nullptr;
br->value = getLocal;
auto decIf =
builder->makeIf(condition,
builder->makeSequence(builder->makeDec(), br),
getLocal);
replaceCurrent(builder->makeSequence(setLocal, decIf));
}
}
}

// Replacing each top-level catch block with try/catch_all(finally) and
// increase counter for catch_all blocks (not handled by visitPop); dec
// counter at the end of catch block try ({fn}-finally-{label})
// +1
// {catch body}
// -1
// catch_all
// -1
// rethrow {fn}-finally-{label}
void visitTry(Try* curr) {
if (getCatchCount(curr) == 0) {
for (size_t i = 0; i < curr->catchBodies.size(); ++i) {
curr->catchBodies[i] = addCatchCounters(
curr->catchBodies[i], i == curr->catchTags.size());
}
}
}
Expression* addCatchCounters(Expression* expression, bool catchAll) {
auto block = expression->dynCast<Block>();
if (block == nullptr) {
block = builder->makeBlock(expression);
}

// catch_all case is not covered by visitPop
if (catchAll) {
block->list.insertAt(0, builder->makeInc());
}

// dec counters at the end of catch
if (block->type == Type::none) {
auto last = block->list[block->list.size() - 1];
if (!last->dynCast<Return>()) {
block->list.push_back(builder->makeDec());
block->finalize();
}
}

auto name =
func->name.toString() + "-finally-" + std::to_string(++finallyNum);
return builder->makeTry(
name,
block,
{},
{builder->makeSequence(builder->makeDec(),
builder->makeRethrow(name))},
block->type);
}
};

Parents parents(func->body);
CountersBuilder builder(*module_);
BranchUtils::BranchTargets branchTargets(func->body);

AddCountersWalker addCountersWalker;
addCountersWalker.func = func;
addCountersWalker.builder = &builder;
addCountersWalker.branchTargets = &branchTargets;
addCountersWalker.parents = &parents;
addCountersWalker.walk(func->body);

EHUtils::handleBlockNestedPops(func, *module_);
}
};

// Add asserts in non-instrumented code.
struct AsyncifyAssertInNonInstrumented : public Pass {
bool isFunctionParallel() override { return true; }
Expand Down Expand Up @@ -1646,6 +1872,9 @@ struct Asyncify : public Pass {
auto relocatable = hasArgument("asyncify-relocatable");
auto secondaryMemory = hasArgument("asyncify-in-secondary-memory");
auto propagateAddList = hasArgument("asyncify-propagate-addlist");
auto ignoreCatchUnwind =
hasArgument("asyncify-ignore-unwind-from-catch");
auto addAsyncifyCounters = asserts || ignoreCatchUnwind;

// Ensure there is a memory, as we need it.
if (secondaryMemory) {
Expand Down Expand Up @@ -1693,7 +1922,7 @@ struct Asyncify : public Pass {
verbose);

// Add necessary globals before we emit code to use them.
addGlobals(module, relocatable);
addGlobals(module, relocatable, addAsyncifyCounters);

// Compute the set of functions we will instrument. All of the passes we run
// below only need to run there.
Expand Down Expand Up @@ -1734,12 +1963,17 @@ struct Asyncify : public Pass {
runner.setValidateGlobally(false);
runner.run();
}
if (asserts) {
if (asserts || addAsyncifyCounters) {
// Add asserts in non-instrumented code. Note we do not use an
// instrumented pass runner here as we do want to run on all functions.
PassRunner runner(module);
runner.add(std::make_unique<AsyncifyAssertInNonInstrumented>(
&analyzer, pointerType, asyncifyMemory));
if (addAsyncifyCounters) {
runner.add(std::make_unique<AsyncifyAddCatchCounters>());
}
if (asserts) {
runner.add(std::make_unique<AsyncifyAssertInNonInstrumented>(
&analyzer, pointerType, asyncifyMemory));
}
runner.setIsNested(true);
runner.setValidateGlobally(false);
runner.run();
Expand All @@ -1765,11 +1999,11 @@ struct Asyncify : public Pass {
}
// Finally, add function support (that should not have been seen by
// the previous passes).
addFunctions(module);
addFunctions(module, asserts, ignoreCatchUnwind);
}

private:
void addGlobals(Module* module, bool imported) {
void addGlobals(Module* module, bool imported, bool addAsyncifyCounters) {
Builder builder(*module);

auto asyncifyState = builder.makeGlobal(ASYNCIFY_STATE,
Expand All @@ -1782,6 +2016,19 @@ struct Asyncify : public Pass {
}
module->addGlobal(std::move(asyncifyState));

if (addAsyncifyCounters) {
auto asyncifyCatchCounter =
builder.makeGlobal(ASYNCIFY_CATCH_COUNTER,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Mutable);
if (imported) {
asyncifyCatchCounter->module = ENV;
asyncifyCatchCounter->base = ASYNCIFY_CATCH_COUNTER;
}
module->addGlobal(std::move(asyncifyCatchCounter));
}

auto asyncifyData = builder.makeGlobal(ASYNCIFY_DATA,
pointerType,
builder.makeConst(pointerType),
Expand All @@ -1793,14 +2040,24 @@ struct Asyncify : public Pass {
module->addGlobal(std::move(asyncifyData));
}

void addFunctions(Module* module) {
void addFunctions(Module* module, bool asserts, bool ignoreCatchUnwind) {
Builder builder(*module);
auto makeFunction = [&](Name name, bool setData, State state) {
auto* body = builder.makeBlock();
if (name == ASYNCIFY_START_UNWIND && (asserts || ignoreCatchUnwind)) {
auto* check = builder.makeIf(
builder.makeBinary(
NeInt32,
builder.makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32),
builder.makeConst(int32_t(0))),
ignoreCatchUnwind ? (Expression*)builder.makeReturn()
: (Expression*)builder.makeUnreachable());
body->list.push_back(check);
}
std::vector<Type> params;
if (setData) {
params.push_back(pointerType);
}
auto* body = builder.makeBlock();
body->list.push_back(builder.makeGlobalSet(
ASYNCIFY_STATE, builder.makeConst(int32_t(state))));
if (setData) {
Expand Down Expand Up @@ -1848,6 +2105,17 @@ struct Asyncify : public Pass {
builder.makeGlobalGet(ASYNCIFY_STATE, Type::i32)));
module->addExport(builder.makeExport(
ASYNCIFY_GET_STATE, ASYNCIFY_GET_STATE, ExternalKind::Function));

if (asserts || ignoreCatchUnwind) {
module->addFunction(builder.makeFunction(
ASYNCIFY_GET_CATCH_COUNTER,
Signature(Type::none, Type::i32),
{},
builder.makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32)));
module->addExport(builder.makeExport(ASYNCIFY_GET_CATCH_COUNTER,
ASYNCIFY_GET_CATCH_COUNTER,
ExternalKind::Function));
}
}

Name createSecondaryMemory(Module* module, Address secondaryMemorySize) {
Expand Down
Loading
Loading