diff --git a/src/BlockPartitionedArrays.jl b/src/BlockPartitionedArrays.jl index 25e4346..aa83bb2 100644 --- a/src/BlockPartitionedArrays.jl +++ b/src/BlockPartitionedArrays.jl @@ -25,6 +25,21 @@ function Base.getindex(a::BlockPRange,inds::Block{1}) a.ranges[inds.n...] end +function PartitionedArrays.matching_local_indices(a::BlockPRange,b::BlockPRange) + c = map(PartitionedArrays.matching_local_indices,blocks(a),blocks(b)) + reduce(&,c,init=true) +end + +function PartitionedArrays.matching_own_indices(a::BlockPRange,b::BlockPRange) + c = map(PartitionedArrays.matching_own_indices,blocks(a),blocks(b)) + reduce(&,c,init=true) +end + +function PartitionedArrays.matching_ghost_indices(a::BlockPRange,b::BlockPRange) + c = map(PartitionedArrays.matching_ghost_indices,blocks(a),blocks(b)) + reduce(&,c,init=true) +end + """ struct BlockPArray{V,T,N,A,B} <: BlockArrays.AbstractBlockArray{T,N} """ diff --git a/src/MultiField.jl b/src/MultiField.jl index 71f2f4d..9987558 100644 --- a/src/MultiField.jl +++ b/src/MultiField.jl @@ -186,7 +186,7 @@ end function FESpaces.interpolate!(objects,free_values::AbstractVector,space::DistributedMultiFieldFESpace) msg = "free_values and FESpace have incompatible index partitions." - @check partition(axes(free_values,1)) === partition(space.gids) msg + @check PartitionedArrays.matching_local_indices(axes(free_values,1),get_free_dof_ids(space)) msg # Interpolate each field field_fe_fun = DistributedSingleFieldFEFunction[] @@ -219,7 +219,7 @@ function FESpaces.interpolate_everywhere!( space::DistributedMultiFieldFESpace ) msg = "free_values and FESpace have incompatible index partitions." - @check partition(axes(free_values,1)) === partition(space.gids) msg + @check PartitionedArrays.matching_local_indices(axes(free_values,1),get_free_dof_ids(space)) msg # Interpolate each field field_fe_fun = DistributedSingleFieldFEFunction[] diff --git a/test/MultiFieldTests.jl b/test/MultiFieldTests.jl index 3e70c2e..d38ff70 100644 --- a/test/MultiFieldTests.jl +++ b/test/MultiFieldTests.jl @@ -17,7 +17,6 @@ function main(distribute, parts, mfs) domain = (0,4,0,4) cells = (4,4) - model = CartesianDiscreteModel(domain,cells) model = CartesianDiscreteModel(ranks,parts,domain,cells) Ω = Triangulation(model)