From a3430324cda3981dd6d14253ba9e6979677be9a6 Mon Sep 17 00:00:00 2001 From: JordiManyer Date: Fri, 12 Apr 2024 17:45:10 +1000 Subject: [PATCH] Fixed BlockPArray bugs --- src/BlockPartitionedArrays.jl | 54 +++++++++++++++++++++++++++++------ test/MultiFieldTests.jl | 8 +++--- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/src/BlockPartitionedArrays.jl b/src/BlockPartitionedArrays.jl index cd9e1c6..f173e5c 100644 --- a/src/BlockPartitionedArrays.jl +++ b/src/BlockPartitionedArrays.jl @@ -145,17 +145,19 @@ end function Base.copyto!(y::BlockPVector,x::BlockPVector) @check blocklength(x) == blocklength(y) - for i in blockaxes(x,1) - copyto!(y[i],x[i]) + yb, xb = blocks(y), blocks(x) + for i in 1:blocksize(x,1) + copyto!(yb[i],xb[i]) end return y end function Base.copyto!(y::BlockPMatrix,x::BlockPMatrix) @check blocksize(x) == blocksize(y) - for i in blockaxes(x,1) - for j in blockaxes(x,2) - copyto!(y[i,j],x[i,j]) + yb, xb = blocks(y), blocks(x) + for i in 1:blocksize(x,1) + for j in 1:blocksize(x,2) + copyto!(yb[i,j],xb[i,j]) end end return y @@ -169,6 +171,8 @@ function Base.fill!(a::BlockPVector,v) end function Base.sum(a::BlockPArray) + # TODO: This could use a single communication, instead of one for each block + # TODO: We could implement a generic reduce, that we apply to sum, all, any, etc.. return sum(map(sum,blocks(a))) end @@ -284,15 +288,47 @@ end # LinearAlgebra API +function Base.:*(a::Number,b::BlockArray) + mortar(map(bi -> a*bi,blocks(b))) +end +Base.:*(b::BlockPMatrix,a::Number) = a*b +Base.:/(b::BlockPVector,a::Number) = (1/a)*b + +function Base.:*(a::BlockPMatrix,b::BlockPVector) + c = similar(b) + mul!(c,a,b) + return c +end + +for op in (:+,:-) + @eval begin + function Base.$op(a::BlockPArray) + mortar(map($op,blocks(a))) + end + function Base.$op(a::BlockPArray,b::BlockPArray) + @assert blocksize(a) == blocksize(b) + mortar(map($op,blocks(a),blocks(b))) + end + end +end + function LinearAlgebra.mul!(y::BlockPVector,A::BlockPMatrix,x::BlockPVector) + o = one(eltype(A)) + mul!(y,A,x,o,o) +end + +function LinearAlgebra.mul!(y::BlockPVector,A::BlockPMatrix,x::BlockPVector,α::Number,β::Number) + yb, Ab, xb = blocks(y), blocks(A), blocks(x) z = zero(eltype(y)) o = one(eltype(A)) - for i in blockaxes(A,2) - fill!(y[i],z) - for j in blockaxes(A,2) - mul!(y[i],A[i,j],x[j],o,o) + for i in 1:blocksize(A,1) + fill!(yb[i],z) + for j in 1:blocksize(A,2) + mul!(yb[i],Ab[i,j],xb[j],α,o) end + rmul!(yb[i],β) end + return y end function LinearAlgebra.dot(x::BlockPVector,y::BlockPVector) diff --git a/test/MultiFieldTests.jl b/test/MultiFieldTests.jl index 00f349e..9a0475a 100644 --- a/test/MultiFieldTests.jl +++ b/test/MultiFieldTests.jl @@ -1,8 +1,7 @@ module MultiFieldTests using Gridap -using Gridap.FESpaces -using Gridap.MultiField +using Gridap.FESpaces, Gridap.MultiField, Gridap.Algebra using GridapDistributed using PartitionedArrays using Test @@ -74,8 +73,9 @@ function main(distribute, parts, mfs) A1 = assemble_matrix(a1,UxP,UxP) A2 = assemble_matrix(a2,UxP,UxP) - x = prandn(partition(axes(A1,2))) - @test norm(A1*x-A2*x) < 1.0e-9 + x1 = allocate_in_domain(A1); fill!(x1,1.0) + x2 = allocate_in_domain(A2); fill!(x2,1.0) + @test norm(A1*x1-A2*x2) < 1.0e-9 end function main(distribute, parts)