Skip to content

Commit

Permalink
dialects: (builtin) constrain DensIntOrFPElementsattr to the correct …
Browse files Browse the repository at this point in the history
…nb of elements (#3637)
  • Loading branch information
jorendumoulin authored Dec 13, 2024
1 parent 1f67b7d commit 46c7930
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 8 deletions.
35 changes: 32 additions & 3 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from xdsl.dialects.arith import ConstantOp
from xdsl.dialects.builtin import (
AnyTensorType,
AnyVectorType,
ArrayAttr,
BFloat16Type,
BytesAttr,
Expand Down Expand Up @@ -237,7 +238,7 @@ def test_IntegerType_packing():


def test_DenseIntOrFPElementsAttr_fp_type_conversion():
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [])
check1 = DenseIntOrFPElementsAttr.tensor_from_list([4, 5], f32, [2])

value1 = check1.get_attrs()[0].value.data
value2 = check1.get_attrs()[1].value.data
Expand All @@ -251,7 +252,7 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion():
t1 = FloatAttr(4.0, f32)
t2 = FloatAttr(5.0, f32)

check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [])
check2 = DenseIntOrFPElementsAttr.tensor_from_list([t1, t2], f32, [2])

value3 = check2.get_attrs()[0].value.data
value4 = check2.get_attrs()[1].value.data
Expand All @@ -264,9 +265,37 @@ def test_DenseIntOrFPElementsAttr_fp_type_conversion():


def test_DenseIntOrFPElementsAttr_from_list():
# legal zero-rank tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5], f32, [])

assert attr.type == AnyTensorType(f32, [])
assert len(attr) == 1

# illegal zero-rank tensor
with pytest.raises(
ValueError, match="A zero-rank tensor can only hold 1 value but 2 were given."
):
DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [])

# legal 1 element tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5], f32, [1])
assert attr.type == AnyTensorType(f32, [1])
assert len(attr) == 1

# legal normal tensor
attr = DenseIntOrFPElementsAttr.tensor_from_list([5.5, 5.6], f32, [2])
assert attr.type == AnyTensorType(f32, [2])
assert len(attr) == 2

# splat initialization
attr = DenseIntOrFPElementsAttr.tensor_from_list([4], f32, [4])
assert attr.type == AnyTensorType(f32, [4])
assert tuple(attr.get_values()) == (4, 4, 4, 4)
assert len(attr) == 4

# vector with inferred shape
attr = DenseIntOrFPElementsAttr.vector_from_list([1, 2, 3, 4], f32)
assert attr.type == AnyVectorType(f32, [4])
assert len(attr) == 4


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/interpreters/test_ml_program_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_ml_program_global_load_constant():
interpreter.register_implementations(MLProgramFunctions())

(result,) = interpreter.run_op(fetch, ())
assert result == ShapedArray(TypedPtr.new_int32([4]), [4])
assert result == ShapedArray(TypedPtr.new_int32([4] * 4), [4])


def test_ml_program_global_load_constant_ex2():
Expand Down
30 changes: 26 additions & 4 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1944,6 +1944,9 @@ def get_shape(self) -> tuple[int, ...]:
def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
return self.type.get_element_type()

def __len__(self) -> int:
return len(self.data)

@property
def shape_is_complete(self) -> bool:
shape = self.get_shape()
Expand Down Expand Up @@ -2040,25 +2043,44 @@ def from_list(
),
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
# zero rank type should only hold 1 value
if not type.get_shape() and len(data) != 1:
raise ValueError(
f"A zero-rank {type.name} can only hold 1 value but {len(data)} were given."
)

# splat value given
if len(data) == 1 and prod(type.get_shape()) != 1:
new_data = (data[0],) * prod(type.get_shape())
else:
new_data = data

if isinstance(type.element_type, AnyFloat):
new_type = cast(RankedStructure[AnyFloat], type)
new_data = cast(Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], data)
new_data = cast(
Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], new_data
)
return DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data)
elif isinstance(type.element_type, IntegerType):
new_type = cast(RankedStructure[IntegerType], type)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IntegerType]], data)
new_data = cast(
Sequence[int] | Sequence[IntegerAttr[IntegerType]], new_data
)
return DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data)
else:
new_type = cast(RankedStructure[IndexType], type)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], data)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], new_data)
return DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data)

@staticmethod
def vector_from_list(
data: Sequence[int] | Sequence[float],
data_type: IntegerType | IndexType | AnyFloat,
shape: Sequence[int] | None = None,
) -> DenseIntOrFPElementsAttr:
t = VectorType(data_type, [len(data)])
if not shape:
shape = [len(data)]
t = VectorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)

@staticmethod
Expand Down

0 comments on commit 46c7930

Please sign in to comment.