Skip to content

Commit

Permalink
use secret.cast instead of custom type conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Nov 28, 2023
1 parent 11efb23 commit 2658e7f
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 126 deletions.
11 changes: 4 additions & 7 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,13 @@ void GenericOp::build(OpBuilder &builder, OperationState &result,
bodyBuilder(builder, result.location, bodyBlock.getArguments());
}

LogicalResult CastOp::fold(CastOp::FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &foldResults) {
OpFoldResult CastOp::fold(CastOp::FoldAdaptor adaptor) {
Value input = getInput();
Value output = getOutput();

// self cast is a no-op
if (input.getType() == output.getType()) {
foldResults.push_back(input);
return success();
return input;
}

// Fold a cast-and-cast-back to a no-op.
Expand All @@ -254,10 +252,9 @@ LogicalResult CastOp::fold(CastOp::FoldAdaptor adaptor,
// folds to use %0 directly in place of %2.
auto inputOp = input.getDefiningOp<CastOp>();
if (!inputOp || output.getType() != inputOp.getInput().getType())
return failure();
return OpFoldResult();

foldResults.append(inputOp.getInput());
return success();
return inputOp.getInput();
}

} // namespace secret
Expand Down
54 changes: 11 additions & 43 deletions lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,17 +133,8 @@ LogicalResult convertOpOperands(secret::GenericOp op, func::FuncOp func,

// Insert a conversion from the original type to the converted type
OpBuilder builder(op);
typeConvertedArgs.push_back(
builder
.create<secret::GenericOp>(
op.getLoc(), opOperand.get(),
secret::SecretType::get(convertedType),
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
auto fromElementsOp = convertIntegerValue(
blockArguments[0], convertedType, b, loc);
b.create<secret::YieldOp>(loc, fromElementsOp.getResult());
})
.getResult(0));
typeConvertedArgs.push_back(builder.create<secret::CastOp>(
op.getLoc(), secret::SecretType::get(convertedType), opOperand.get()));
}

return success();
Expand All @@ -152,7 +143,7 @@ LogicalResult convertOpOperands(secret::GenericOp op, func::FuncOp func,
/// Convert a secret.generic's results from secret.secret<tensor<3xi1>>
/// to secret.secret<i3>.
LogicalResult convertOpResults(secret::GenericOp op,
DenseSet<Operation *> &generics,
DenseSet<Operation *> &castOps,
SmallVector<Value> &typeConvertedResults) {
for (Value opResult : op.getResults()) {
if (!opResult.getType().isa<secret::SecretType>()) {
Expand Down Expand Up @@ -189,33 +180,10 @@ LogicalResult convertOpResults(secret::GenericOp op,
// tensor version.
OpBuilder builder(op);
builder.setInsertionPointAfter(op);
auto genericOp = builder.create<secret::GenericOp>(
op.getLoc(), opResult, secret::SecretType::get(reassembledType),
[&](OpBuilder &builder, Location loc, ValueRange blockArguments) {
ImplicitLocOpBuilder b(op.getLoc(), builder);
// At this point, there is a single block argument that must
// be a tensor<Nxi1>
Value arg = blockArguments[0];
Value accum = b.create<arith::ConstantOp>(
reassembledType, b.getIntegerAttr(reassembledType, 0));

for (int i = 0; i < reassembledType.getWidth(); i++) {
// x |= arg[i] << i
auto cI = b.create<arith::ConstantOp>(
reassembledType, b.getIntegerAttr(reassembledType, i));
auto cIndex = b.create<arith::ConstantOp>(b.getIndexType(),
b.getIndexAttr(i));
auto extractedBit =
b.create<tensor::ExtractOp>(arg, cIndex.getResult());
auto shifted = b.create<arith::ShLIOp>(
b.create<arith::ExtSIOp>(reassembledType, extractedBit), cI);
accum = b.create<arith::OrIOp>(loc, accum, shifted);
}

b.create<secret::YieldOp>(loc, accum);
});
generics.insert(genericOp);
typeConvertedResults.push_back(genericOp.getResult(0));
auto castOp = builder.create<secret::CastOp>(
op.getLoc(), secret::SecretType::get(reassembledType), opResult);
castOps.insert(castOp);
typeConvertedResults.push_back(castOp.getOutput());
}

return success();
Expand Down Expand Up @@ -297,11 +265,11 @@ LogicalResult runOnGenericOp(MLIRContext *context, secret::GenericOp op,
returnOp.erase();
func.erase();

DenseSet<Operation *> generics;
DenseSet<Operation *> castOps;
SmallVector<Value> typeConvertedResults;
generics.reserve(op->getNumResults());
castOps.reserve(op->getNumResults());
typeConvertedResults.reserve(op->getNumResults());
if (failed(convertOpResults(op, generics, typeConvertedResults))) {
if (failed(convertOpResults(op, castOps, typeConvertedResults))) {
return failure();
}

Expand All @@ -312,7 +280,7 @@ LogicalResult runOnGenericOp(MLIRContext *context, secret::GenericOp op,

op.getResults().replaceUsesWithIf(
typeConvertedResults, [&](OpOperand &operand) {
return !generics.contains(operand.getOwner());
return !castOps.contains(operand.getOwner());
});
return success();
}
Expand Down
6 changes: 5 additions & 1 deletion tests/yosys_optimizer/add_one.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ module {
func.func @add_one(%in: !secret.secret<i8>) -> (!secret.secret<i8>) {
%one = arith.constant 1 : i8
// Generic to convert the i8 to a tensor
// CHECK: secret.generic
// CHECK: secret.cast
// CHECK-SAME: !secret.secret<i8> to !secret.secret<tensor<8xi1>>

// CHECK: secret.generic
%1 = secret.generic
Expand All @@ -15,6 +16,9 @@ module {
%2 = arith.addi %IN, %ONE : i8
secret.yield %2 : i8
} -> (!secret.secret<i8>)

// CHECK: secret.cast
// CHECK-SAME: !secret.secret<tensor<8xi1>> to !secret.secret<i8>
return %1 : !secret.secret<i8>
}
}
79 changes: 10 additions & 69 deletions tests/yosys_optimizer/arith_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,81 +15,22 @@ func.func @ops(
secret.yield %2 : i3
} -> (!secret.secret<i3>)
return %1 : !secret.secret<i3>
// Converting arg0
// CHECK: secret.generic
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: tensor.from_elements
// CHECK: secret.yield

// Converting arg1
// CHECK: secret.generic
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: tensor.from_elements
// CHECK: secret.yield

// Converting arg2
// CHECK: secret.generic
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: tensor.from_elements
// CHECK: secret.yield

// Converting arg3
// CHECK: secret.generic
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: arith.andi
// CHECK: arith.shrsi
// CHECK: arith.trunci
// CHECK: tensor.from_elements
// CHECK: secret.yield
// CHECK: secret.cast
// CHECK-SAME: !secret.secret<i3> to !secret.secret<tensor<3xi1>>
// CHECK: secret.cast
// CHECK-SAME: !secret.secret<i3> to !secret.secret<tensor<3xi1>>
// CHECK: secret.cast
// CHECK-SAME: !secret.secret<i3> to !secret.secret<tensor<3xi1>>
// CHECK: secret.cast
// CHECK-SAME: !secret.secret<i3> to !secret.secret<tensor<3xi1>>

// Main computation
// CHECK: secret.generic
// CHECK-COUNT-7: comb.truth_table
// CHECK: secret.yield
// CHECK-SAME: tensor<3xi1>

// Reverse-converting main computation result
// CHECK: secret.generic
// CHECK: arith.extsi
// CHECK: arith.shli
// CHECK: arith.ori
// CHECK: arith.extsi
// CHECK: arith.shli
// CHECK: arith.ori
// CHECK: arith.extsi
// CHECK: arith.shli
// CHECK: arith.ori
// CHECK: secret.yield
// CHECK-SAME: i3

// CHECK: secret.cast
// CHECK-SAME: !secret.secret<tensor<3xi1>> to !secret.secret<i3>
// CHECK: return
}
6 changes: 3 additions & 3 deletions tests/yosys_optimizer/chunk_connections.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// CHECK-LABEL: @for_loop
func.func @for_loop(%ARG0: !secret.secret<i8>, %ARG1: !secret.secret<i8>) -> !secret.secret<i32> {
// convert two ARGs
// CHECK: secret.generic
// CHECK: secret.generic
// CHECK: secret.cast
// CHECK: secret.cast

// CHECK: secret.generic
// CHECK-NOT: arith.extsi
Expand All @@ -25,7 +25,7 @@ func.func @for_loop(%ARG0: !secret.secret<i8>, %ARG1: !secret.secret<i8>) -> !se
secret.yield %5 : i32
} -> (!secret.secret<i32>)

// CHECK: secret.generic
// CHECK: secret.cast
// CHECK: return
return %1 : !secret.secret<i32>
}
8 changes: 5 additions & 3 deletions tests/yosys_optimizer/micro_speech_for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
func.func @for_25_20_8(
%c98: !secret.secret<i32>, %c99: !secret.secret<i32>, %c100: !secret.secret<i8>) -> (!secret.secret<i8>) {
// convert three args
// CHECK: secret.generic
// CHECK: secret.generic
// CHECK: secret.generic
// CHECK: secret.cast
// CHECK: secret.cast
// CHECK: secret.cast

// The only arith op we expect is arith.constant
// CHECK-NOT: arith.{{^constant}}
Expand Down Expand Up @@ -44,6 +44,8 @@ func.func @for_25_20_8(
%121 = arith.trunci %120 : i32 to i8
secret.yield %121 : i8
} -> (!secret.secret<i8>)

// CHECK: secret.cast
// CHECK: return
func.return %1 : !secret.secret<i8>
}

0 comments on commit 2658e7f

Please sign in to comment.