Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remaining bit operations #4

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 270 additions & 9 deletions asl_xdsl/dialects/asl.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,217 @@ def __init__(
)


@irdl_op_definition
class ZeroExtendBitsOp(IRDLOperation):
"""A bit vector zero-extend operation."""

name = "asl.zero_extend_bits"

S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(S)
rhs = operand_def(IntegerType())
res = result_def(T)
Comment on lines +836 to +841
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can actually just replace this with

Suggested change
S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))
lhs = operand_def(S)
rhs = operand_def(IntegerType())
res = result_def(T)
lhs = operand_def(BitVectorType)
rhs = operand_def(IntegerType())
res = result_def(BitVectorType)

The VarConstraint are only really used whenever you need two types to be exactly equal


assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class SignExtendBitsOp(IRDLOperation):
"""A bit vector zero-extend operation."""

name = "asl.sign_extend_bits"

S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(S)
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class AppendBitsOp(IRDLOperation):
"""A bit vector append operation."""

name = "asl.append_bits"

S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))
U: ClassVar = VarConstraint("U", BaseAttr(BitVectorType))

lhs = operand_def(S)
rhs = operand_def(T)
res = result_def(U)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class ReplicateBitsOp(IRDLOperation):
"""A bit vector replication operation."""

name = "asl.replicate_bits"

S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(S)
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class ZerosBitsOp(IRDLOperation):
"""A bit vector all-zeros operation."""

name = "asl.zeros_bits"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

arg = operand_def(IntegerType())
res = result_def(T)

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"

def __init__(
self,
arg: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[arg],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class OnesBitsOp(IRDLOperation):
"""A bit vector all-ones operation."""

name = "asl.ones_bits"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

arg = operand_def(IntegerType())
res = result_def(T)

assembly_format = "$arg `:` type($arg) `->` type($res) attr-dict"

def __init__(
self,
arg: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[arg],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class MkMaskBitsOp(IRDLOperation):
"""
A bit vector mask generation operation.
`mk_mask(x, N) : bits(N)` consists of `x` ones.
For example, `mk_mask(3, 8) == '0000 0111'`.
"""

name = "asl.mk_mask"

T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

lhs = operand_def(IntegerType())
rhs = operand_def(IntegerType())
res = result_def(T)

assembly_format = (
"$lhs `,` $rhs `:` `(` type($lhs) `,` type($rhs) `)` `->` type($res) attr-dict"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
res: SSAValue,
attr_dict: Mapping[str, Attribute] = {},
):
super().__init__(
operands=[lhs, rhs],
result_types=[res.type],
attributes=attr_dict,
)


@irdl_op_definition
class NotBitsOp(IRDLOperation):
"""A bitwise NOT operation."""
Expand Down Expand Up @@ -1166,28 +1377,70 @@ def __init__(


@irdl_op_definition
class SliceSingleOp(IRDLOperation):
"""Slice a single element from a bit vector."""
class GetSliceOp(IRDLOperation):
"""Extract a slice from a bit vector."""

name = "asl.get_slice"

name = "asl.slice_single"
S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

bits = operand_def(BitVectorType)
bits = operand_def(S)
index = operand_def(IntegerType)
width = operand_def(IntegerType)

res = result_def(BitVectorType(1))
res = result_def(T)

assembly_format = (
"$bits `[` $index `]` `:` type($bits) `[` type($index) `]` attr-dict"
"$bits `,` $index `,` $width "
"`:` `(` type($bits) `,` type($index) `,` type($width) `)` "
"`->` type($res) attr-dict"
)

def __init__(
self,
bits: SSAValue,
index: SSAValue,
width: SSAValue,
res: SSAValue,
):
super().__init__(
operands=[bits, index, width],
result_types=[res.type],
)


@irdl_op_definition
class SetSliceOp(IRDLOperation):
"""Insert a slice into a bit vector."""

name = "asl.set_slice"

S: ClassVar = VarConstraint("S", BaseAttr(BitVectorType))
T: ClassVar = VarConstraint("T", BaseAttr(BitVectorType))

bits = operand_def(S)
index = operand_def(IntegerType)
width = operand_def(IntegerType)
rhs = operand_def(T)
res = result_def(S)

assembly_format = (
"$bits `,` $index `,` $width `,` $rhs "
"`:` `(` type($bits) `,` type($index) `,` type($width) `,` type($rhs) `)` "
"`->` type($res) attr-dict"
)

def __init__(
self,
bits: SSAValue,
index: SSAValue,
width: SSAValue,
rhs: SSAValue,
):
super().__init__(
operands=[bits, index],
result_types=[BitVectorType(1)],
operands=[bits, index, width, rhs],
result_types=[bits.type],
)


Expand Down Expand Up @@ -1231,6 +1484,13 @@ def __init__(
LslBitsOp,
LsrBitsOp,
AsrBitsOp,
ZeroExtendBitsOp,
SignExtendBitsOp,
AppendBitsOp,
ReplicateBitsOp,
ZerosBitsOp,
OnesBitsOp,
MkMaskBitsOp,
AddBitsIntOp,
SubBitsIntOp,
NotBitsOp,
Expand All @@ -1244,7 +1504,8 @@ def __init__(
FuncOp,
CallOp,
# Slices
SliceSingleOp,
GetSliceOp,
SetSliceOp,
],
[
IntegerType,
Expand Down
Loading
Loading