Skip to content

Commit

Permalink
Update llvm-project to 830c0b9
Browse files Browse the repository at this point in the history
- builder.getSymbolRefAttr is gone.
- OpAsmOpInterface's getAsmResultNames method needs explicit override
- a bunch of churn for builtin.func needing to be made explicit (and
  sometimes implicit?)
- operation printers no longer need to print the operation name
  themselves.
- snuck in beneficial trivial addition to TmpDeleteDeadIREEListsPass to
  test a particular upstream change e2e with my local patchset.
  • Loading branch information
silvasean committed Sep 3, 2021
1 parent 9cc4fdc commit 1dec561
Show file tree
Hide file tree
Showing 21 changed files with 184 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class IREE_AliasedSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
let storageType = [{ FlatSymbolRefAttr }];
let returnType = [{ StringRef }];
let valueType = NoneType;
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
let constBuilderCall = "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)";
}

class IREE_AnyPtrOf<list<Type> types> :
Expand Down
2 changes: 1 addition & 1 deletion external/llvm-project
Submodule llvm-project updated 4541 files
4 changes: 2 additions & 2 deletions frontends/pytorch/test/node_import/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
mb = torch_mlir.ModuleBuilder()


# CHECK-LABEL: builtin.func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
# CHECK-LABEL: func @__torch__.dict_literal_empty() -> !torch.dict<!torch.str, !torch.tensor> {
# CHECK: %[[DICT:.*]] = torch.prim.DictConstruct keys() values() -> !torch.dict<!torch.str, !torch.tensor>
# CHECK: return %[[DICT]] : !torch.dict<!torch.str, !torch.tensor>
@mb.import_function
Expand All @@ -21,7 +21,7 @@ def dict_literal_empty() -> Dict[str, torch.Tensor]:
return {}


# CHECK-LABEL: builtin.func @__torch__.dict_literal(
# CHECK-LABEL: func @__torch__.dict_literal(
# CHECK-SAME: %[[K0:.*]]: !torch.str, %[[V0:.*]]: !torch.tensor,
# CHECK-SAME: %[[K1:.*]]: !torch.str, %[[V1:.*]]: !torch.tensor)
# CHECK-SAME: -> !torch.dict<!torch.str, !torch.optional<!torch.tensor>> {
Expand Down
6 changes: 3 additions & 3 deletions frontends/pytorch/test/node_import/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
NT = NamedTuple('NT', [('f1', Optional[torch.Tensor]),
('f2', Optional[torch.Tensor])])

# CHECK-LABEL: builtin.func @__torch__.tuple(
# CHECK-LABEL: func @__torch__.tuple(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.tensor, !torch.tensor> {
Expand All @@ -28,7 +28,7 @@ def tuple(t0, t1):
return t0, t1


# CHECK-LABEL: builtin.func @__torch__.tuple_optional(
# CHECK-LABEL: func @__torch__.tuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
Expand All @@ -47,7 +47,7 @@ def tuple_optional(
return t0, t1


# CHECK-LABEL: builtin.func @__torch__.namedtuple_optional(
# CHECK-LABEL: func @__torch__.namedtuple_optional(
# CHECK-SAME: %[[T0:.*]]: !torch.tensor,
# CHECK-SAME: %[[T1:.*]]: !torch.tensor) ->
# CHECK-SAME: !torch.tuple<!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>> {
Expand Down
12 changes: 8 additions & 4 deletions include/npcomp/Dialect/Basicpy/IR/BasicpyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def CompareOperationAttr : StrEnumAttr<
//===----------------------------------------------------------------------===//

def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "A constant from the Python3 numeric type hierarchy";
let description = [{
Basicpy re-uses core MLIR types to represent the Python3 numeric type
Expand Down Expand Up @@ -141,7 +142,8 @@ def Basicpy_NumericConstantOp : Basicpy_Op<"numeric_constant", [
}

def Basicpy_BoolConstantOp : Basicpy_Op<"bool_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "A boolean constant";
let description = [{
A constant of type !basicpy.BoolType that can take either an i1 value
Expand Down Expand Up @@ -220,7 +222,8 @@ def Basicpy_BuildTupleOp : Basicpy_Op<"build_tuple", [NoSideEffect]> {
}

def Basicpy_BytesConstantOp : Basicpy_Op<"bytes_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Constant bytes value";
let description = [{
A bytes value of BytesType. The value is represented by a StringAttr.
Expand Down Expand Up @@ -251,7 +254,8 @@ def Basicpy_SingletonOp : Basicpy_Op<"singleton", [
}

def Basicpy_StrConstantOp : Basicpy_Op<"str_constant", [
ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Constant string value";
let description = [{
A string value of StrType. The value is represented by a StringAttr
Expand Down
15 changes: 10 additions & 5 deletions include/npcomp/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,8 @@ def Torch_PrimIfYieldOp : Torch_Op<"prim.If.yield", [
//===----------------------------------------------------------------------===//

def Torch_ConstantNoneOp : Torch_Op<"constant.none",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Get the singleton None value.";
let description = [{
Not to be confused with the `mlir::NoneType`. Be careful to use
Expand All @@ -547,7 +548,8 @@ def Torch_ConstantNoneOp : Torch_Op<"constant.none",
}

def Torch_ConstantStrOp : Torch_Op<"constant.str",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant str value.";
let description = [{
Note: Strings in Python (and TorchScript) are immutable.
Expand All @@ -563,7 +565,8 @@ def Torch_ConstantStrOp : Torch_Op<"constant.str",
}

def Torch_ConstantIntOp : Torch_Op<"constant.int",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant `int` value.";
let description = [{
Note: TorchScript represents integers as 64-bit signed values, unlike
Expand All @@ -581,7 +584,8 @@ def Torch_ConstantIntOp : Torch_Op<"constant.int",
}

def Torch_ConstantFloatOp : Torch_Op<"constant.float",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant `float` value.";
let description = [{
Note: TorchScript represents `float` as 64-bit floating point values.
Expand All @@ -599,7 +603,8 @@ def Torch_ConstantFloatOp : Torch_Op<"constant.float",
}

def Torch_ConstantBoolOp : Torch_Op<"constant.bool",
[ConstantLike, NoSideEffect, DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
[ConstantLike, NoSideEffect,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
let summary = "Materialize a constant `bool` value.";
let description = [{
}];
Expand Down
2 changes: 1 addition & 1 deletion include/npcomp/Dialect/Torch/IR/TorchTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def Torch_NnModuleType : Torch_Type<"NnModule", "nn.Module"> {
let parser = [{
if (parser.parseLess())
return Type();
StringRef className;
std::string className;
if ($_parser.parseOptionalString(&className))
return Type();
if ($_parser.parseGreater())
Expand Down
9 changes: 4 additions & 5 deletions lib/Dialect/Basicpy/IR/BasicpyOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ static ParseResult parseNumericConstantOp(OpAsmParser &parser,
}

static void print(OpAsmPrinter &p, NumericConstantOp op) {
p << "basicpy.numeric_constant ";
p << " ";
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});

if (op->getAttrs().size() > 1)
Expand Down Expand Up @@ -176,7 +176,6 @@ static ParseResult parseExecOp(OpAsmParser &parser, OperationState *result) {
}

static void print(OpAsmPrinter &p, ExecOp op) {
p << op.getOperationName();
p.printOptionalAttrDictWithKeyword(op->getAttrs());
p.printRegion(op.body());
}
Expand Down Expand Up @@ -230,7 +229,7 @@ static ParseResult parseFuncTemplateOp(OpAsmParser &parser,
}

static void print(OpAsmPrinter &p, FuncTemplateOp op) {
p << op.getOperationName() << " ";
p << " ";
p.printSymbolName(op.getName());
p.printOptionalAttrDictWithKeyword(op->getAttrs(),
{SymbolTable::getSymbolAttrName()});
Expand Down Expand Up @@ -294,7 +293,7 @@ static void print(OpAsmPrinter &p, SlotObjectMakeOp op) {
return;
}

p << op.getOperationName() << "(";
p << "(";
p.printOperands(op.slots());
p << ")";
p.printOptionalAttrDict(op->getAttrs(), {"className"});
Expand Down Expand Up @@ -358,7 +357,7 @@ static void print(OpAsmPrinter &p, SlotObjectGetOp op) {
return;
}

p << op.getOperationName() << " ";
p << " ";
p.printOperand(op.object());
p << "[" << op.index() << "]";
p.printOptionalAttrDict(op->getAttrs(), {"index"});
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/Refbackrt/IR/RefbackrtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ using namespace mlir::NPCOMP::refbackrt;
//===----------------------------------------------------------------------===//

static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
p << "refbackrt.module_metadata";
p.printOptionalAttrDictWithKeyword(op->getAttrs());
p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
Expand Down
11 changes: 6 additions & 5 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ static IntegerAttr getI64IntegerAttr(MLIRContext *context, int64_t value) {
//===----------------------------------------------------------------------===//

LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, function());
auto func =
symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, functionAttr());
if (!func)
return emitError() << "'@" << function()
<< "' does not reference a valid function";
Expand Down Expand Up @@ -132,8 +133,8 @@ bool isValidSubtype(Type subtype, Type type) {
}

LogicalResult NnModuleOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto classType =
symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(*this, getClassName());
auto classType = symbolTable.lookupNearestSymbolFrom<ClassTypeOp>(
*this, SymbolRefAttr::get(getContext(), getClassName()));
if (!classType)
return emitError() << "'" << getClassName()
<< "' does not reference a valid class type";
Expand Down Expand Up @@ -297,7 +298,7 @@ static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
}

static void print(OpAsmPrinter &p, PrimIfOp op) {
p << PrimIfOp::getOperationName() << " " << op.condition();
p << " " << op.condition();
p << " -> (" << op.getResultTypes() << ")";
p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
p << " else";
Expand Down Expand Up @@ -748,7 +749,7 @@ static ParseResult parseConstantIntOp(OpAsmParser &parser,
}

static void print(OpAsmPrinter &p, Torch::ConstantIntOp op) {
p << Torch::ConstantIntOp::getOperationName() << " ";
p << " ";
p << op.value().getSExtValue();
p.printOptionalAttrDict(op->getAttrs(), {"value"});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TmpDeleteDeadIREEListsPass
SmallVector<Operation *> deadOps;
deadOps.push_back(op);
for (auto &use : op.getResult().getUses()) {
if (isa<iree::ListSetOp>(use.getOwner())) {
if (isa<iree::ListSetOp, iree::ListResizeOp>(use.getOwner())) {
deadOps.push_back(use.getOwner());
} else {
// We can't analyze the list op if it is used by something else.
Expand Down
3 changes: 2 additions & 1 deletion lib/RefBackend/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,8 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
auto wrapper = createWrapperFunc(originalFunc);
op.getResult().setType(LLVMPointerType::get(wrapper.getType()));
Builder builder(op.getContext());
op->setAttr("global_name", builder.getSymbolRefAttr(wrapper.getName()));
op->setAttr("global_name",
SymbolRefAttr::get(builder.getContext(), wrapper.getName()));
});
}
};
Expand Down
6 changes: 3 additions & 3 deletions lib/RefBackend/LowerToRefbackrtABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ static LogicalResult createModuleMetadata(ModuleOp module) {

// Add attributes that are valid for every func (funcName, numInputs,
// numOutputs)
namedAttrs.push_back(
std::make_pair(Identifier::get("funcName", module.getContext()),
builder.getSymbolRefAttr(func.getName())));
namedAttrs.push_back(std::make_pair(
Identifier::get("funcName", module.getContext()),
SymbolRefAttr::get(builder.getContext(), func.getName())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numInputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumArguments())));
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TorchToIREE/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

// RUN: npcomp-opt <%s -convert-torch-to-iree -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: builtin.func @forward(
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[ARG_TORCH:.*]]: !torch.float) -> !torch.list<!torch.float> {
// CHECK: %[[ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
// CHECK: %[[ALSO_ARG:.*]] = torch_c.to_f64 %[[ARG_TORCH]]
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/TorchToLinalg/flatten.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func @torch.aten.flatten.using_ints$basic_negative(%arg0: !torch.vtensor<[3,3,2,

// -----

// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_front(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1, 2], [3]] : tensor<3x3x2x2xf32> into tensor<18x2xf32>
Expand All @@ -53,7 +53,7 @@ func @torch.aten.flatten.using_ints$flatten_front(%arg0: !torch.vtensor<[3,3,2,2

// -----

// CHECK-LABEL: builtin.func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-LABEL: func @torch.aten.flatten.using_ints$flatten_back(
// CHECK-SAME: %[[TENSOR:.*]]: !torch.vtensor<[3,3,2,2],f32>) -> !torch.vtensor<[?,12],f32> {
// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[TENSOR]] : !torch.vtensor<[3,3,2,2],f32> -> tensor<3x3x2x2xf32>
// CHECK: %[[COLLAPSED:.*]] = linalg.tensor_collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0], [1, 2, 3]] : tensor<3x3x2x2xf32> into tensor<3x12xf32>
Expand Down
2 changes: 1 addition & 1 deletion test/Conversion/TorchToLinalg/pooling.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: npcomp-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: builtin.func @forward
// CHECK-LABEL: func @forward
builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
Expand Down
22 changes: 11 additions & 11 deletions test/Dialect/Basicpy/functions.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,52 @@
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @positional
func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @positional(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw []
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw [] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}

// -----
// CHECK-LABEL: func @kwValid
func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @kwValid(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}

// -----
// CHECK-LABEL: func @posArgPack
func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @posArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}

// -----
// CHECK-LABEL: func @kwArgPack
func @kwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @kwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// CHECK: basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"]
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}

// -----
func @kwOverflow(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @kwOverflow(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{expected <= kw arg names vs args}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["second", "third", "fourth"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}

// -----
func @badPosArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @badPosArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{positional arg pack must be the first kw arg}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["*", "*"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
}

// -----
func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
builtin.func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -> !basicpy.UnknownType {
// expected-error @+1 {{kw arg pack must be the last kw arg}}
%0 = basicpy.func_template_call @foobar(%arg0, %arg1) kw ["**", "next"] : (!basicpy.UnknownType, !basicpy.UnknownType) -> !basicpy.UnknownType
return %0 : !basicpy.UnknownType
Expand All @@ -62,20 +62,20 @@ func @badKwArgPack(%arg0 : !basicpy.UnknownType, %arg1 : !basicpy.UnknownType) -

// -----
// CHECK-LABEL: module @valid_template
module @valid_template {
builtin.module @valid_template {
// CHECK: basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
basicpy.func_template @__global$pkg.foobar attributes {py_bind = ["#abs"]} {
// CHECK: func @forInts(%arg0: i32) -> i32
func @forInts(%arg0 : i32) -> i32 {
builtin.func @forInts(%arg0 : i32) -> i32 {
return %arg0 : i32
}
}
}

// -----
module @invalid_template {
builtin.module @invalid_template {
basicpy.func_template @__global$pkg.foobar {
// expected-error @+1 {{illegal operation in func_template}}
module {}
builtin.module {}
}
}
Loading

0 comments on commit 1dec561

Please sign in to comment.