Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(compiler): Write universal exports on linked module #2234

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 123 additions & 159 deletions compiler/src/codegen/comp_utils.re
Original file line number Diff line number Diff line change
Expand Up @@ -153,36 +153,6 @@ let load =
);
};

let is_grain_env = str => grain_env_name == str;

let get_exported_names = (~function_names=?, ~global_names=?, wasm_mod) => {
let num_exports = Export.get_num_exports(wasm_mod);
let exported_names: Hashtbl.t(string, string) = Hashtbl.create(10);
for (i in 0 to num_exports - 1) {
let export = Export.get_export_by_index(wasm_mod, i);
let export_kind = Export.export_get_kind(export);
let exported_name = Export.get_name(export);
let internal_name = Export.get_value(export);

if (export_kind == Export.external_function) {
let new_internal_name =
switch (function_names) {
| Some(function_names) => Hashtbl.find(function_names, internal_name)
| None => internal_name
};
Hashtbl.add(exported_names, exported_name, new_internal_name);
} else if (export_kind == Export.external_global) {
let new_internal_name =
switch (global_names) {
| Some(global_names) => Hashtbl.find(global_names, internal_name)
| None => internal_name
};
Hashtbl.add(exported_names, exported_name, new_internal_name);
};
};
exported_names;
};

let type_of_repr = repr => {
Types.(
switch (repr) {
Expand All @@ -195,152 +165,146 @@ let type_of_repr = repr => {
};

let write_universal_exports =
(wasm_mod, {Cmi_format.cmi_sign}, exported_names) => {
Types.(
Type_utils.(
List.iter(
item => {
switch (item) {
| TSigValue(
id,
{
val_repr: ReprFunction(args, rets, direct),
val_fullpath: path,
},
) =>
let name = Ident.name(id);
let exported_name = "GRAIN$EXPORT$" ++ name;
let internal_global_name =
Hashtbl.find(exported_names, exported_name);
let get_closure = () =>
Expression.Global_get.make(
(wasm_mod, {Cmi_format.cmi_sign}, exports, resolve) => {
open Types;
open Type_utils;
let export_map = Hashtbl.create(128);
List.iter(
e => {
switch (e) {
| WasmFunctionExport({ex_function_name, ex_function_internal_name}) =>
Hashtbl.add(
export_map,
ex_function_name,
resolve(ex_function_internal_name),
)
| WasmGlobalExport(_) => ()
}
},
exports,
);
List.iter(
item => {
switch (item) {
| TSigValue(id, {val_repr: ReprFunction(args, rets, direct)}) =>
let name = Ident.name(id);
let internal_name = Hashtbl.find(export_map, name);
let get_closure = () =>
Expression.Global_get.make(wasm_mod, internal_name, Type.int32);
let arguments =
List.mapi(
(i, arg) =>
Expression.Local_get.make(wasm_mod, i, type_of_repr(arg)),
args,
);
let arguments = [get_closure(), ...arguments];
let call_result_types =
Type.create(
Array.of_list(
List.map(type_of_repr, rets == [] ? [WasmI32] : rets),
),
);
let function_call =
switch (direct) {
| Direct({name}) =>
Expression.Call.make(
wasm_mod,
internal_name,
arguments,
call_result_types,
)
| Indirect =>
let call_arg_types =
Type.create(
Array.of_list(List.map(type_of_repr, [WasmI32, ...args])),
);
let func_ptr =
Expression.Load.make(
wasm_mod,
internal_global_name,
4,
8,
2,
Type.int32,
get_closure(),
grain_memory,
);
let arguments =
List.mapi(
(i, arg) =>
Expression.Local_get.make(wasm_mod, i, type_of_repr(arg)),
args,
);
let arguments = [get_closure(), ...arguments];
let call_result_types =
Type.create(
Array.of_list(
List.map(type_of_repr, rets == [] ? [WasmI32] : rets),
Expression.Call_indirect.make(
wasm_mod,
grain_global_function_table,
func_ptr,
arguments,
call_arg_types,
call_result_types,
);
| Unknown => failwith("Impossible: Unknown function call type")
};
let function_body =
switch (rets) {
| [] => Expression.Drop.make(wasm_mod, function_call)
| _ => function_call
};
let function_body =
Expression.Block.make(
wasm_mod,
"closure_incref",
[
Expression.If.make(
wasm_mod,
Expression.Binary.make(
wasm_mod,
Op.ne_int32,
get_closure(),
Expression.Const.make(wasm_mod, Literal.int32(0l)),
),
);
let function_call =
switch (direct) {
| Direct({name}) =>
Expression.Call.make(
store(
wasm_mod,
Hashtbl.find(exported_names, name),
arguments,
call_result_types,
)
| Indirect =>
let call_arg_types =
Type.create(
Array.of_list(
List.map(type_of_repr, [WasmI32, ...args]),
),
);
let func_ptr =
Expression.Load.make(
Expression.Binary.make(
wasm_mod,
4,
8,
2,
Type.int32,
Op.sub_int32,
get_closure(),
grain_memory,
);
Expression.Call_indirect.make(
wasm_mod,
grain_global_function_table,
func_ptr,
arguments,
call_arg_types,
call_result_types,
);
| Unknown => failwith("Impossible: Unknown function call type")
};
let function_body =
switch (rets) {
| [] => Expression.Drop.make(wasm_mod, function_call)
| _ => function_call
};
let function_body =
Expression.Block.make(
wasm_mod,
"closure_incref",
[
Expression.If.make(
Expression.Const.make(wasm_mod, Literal.int32(8l)),
),
Expression.Binary.make(
wasm_mod,
Expression.Binary.make(
wasm_mod,
Op.ne_int32,
get_closure(),
Expression.Const.make(wasm_mod, Literal.int32(0l)),
),
store(
Op.add_int32,
load(
wasm_mod,
Expression.Binary.make(
wasm_mod,
Op.sub_int32,
get_closure(),
Expression.Const.make(wasm_mod, Literal.int32(8l)),
),
Expression.Binary.make(
wasm_mod,
Op.add_int32,
load(
wasm_mod,
Expression.Binary.make(
wasm_mod,
Op.sub_int32,
get_closure(),
Expression.Const.make(
wasm_mod,
Literal.int32(8l),
),
),
),
Expression.Const.make(wasm_mod, Literal.int32(1l)),
),
),
Expression.Null.make(),
Expression.Const.make(wasm_mod, Literal.int32(1l)),
),
function_body,
],
);
let arg_types =
Type.create(Array.of_list(List.map(type_of_repr, args)));
let result_types =
Type.create(Array.of_list(List.map(type_of_repr, rets)));
ignore @@
Function.add_function(
wasm_mod,
name,
arg_types,
result_types,
[||],
),
Expression.Null.make(),
),
function_body,
);
// Remove existing Grain export (if any)
Export.remove_export(wasm_mod, name);
ignore @@ Export.add_function_export(wasm_mod, name, name);
| TSigValue(_)
| TSigType(_)
| TSigTypeExt(_)
| TSigModule(_)
| TSigModType(_) => ()
}
},
cmi_sign,
)
)
],
);
let arg_types =
Type.create(Array.of_list(List.map(type_of_repr, args)));
let result_types =
Type.create(Array.of_list(List.map(type_of_repr, rets)));
ignore @@
Function.add_function(
wasm_mod,
name,
arg_types,
result_types,
[||],
function_body,
);
ignore @@ Export.add_function_export(wasm_mod, name, name);
| TSigValue(_)
| TSigType(_)
| TSigTypeExt(_)
| TSigModule(_)
| TSigModType(_) => ()
}
},
cmi_sign,
);
};
12 changes: 1 addition & 11 deletions compiler/src/codegen/comp_utils.rei
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,5 @@ let load:
) =>
Expression.t;

let is_grain_env: string => bool;

let get_exported_names:
(
~function_names: Hashtbl.t(string, string)=?,
~global_names: Hashtbl.t(string, string)=?,
Module.t
) =>
Hashtbl.t(string, string);

let write_universal_exports:
(Module.t, Cmi_format.cmi_infos, Hashtbl.t(string, string)) => unit;
(Module.t, Cmi_format.cmi_infos, list(export), string => string) => unit;
Loading
Loading