diff --git a/GNNGraphs/src/gnngraph.jl b/GNNGraphs/src/gnngraph.jl index 1c826d748..0a818b4d2 100644 --- a/GNNGraphs/src/gnngraph.jl +++ b/GNNGraphs/src/gnngraph.jl @@ -146,11 +146,9 @@ function GNNGraph(data::D; edata = normalize_graphdata(edata, default_name = :e, n = num_edges, duplicate_if_needed = true) - # don't force the shape of the data when there is only one graph - gdata = normalize_graphdata(gdata, default_name = :u, - n = num_graphs > 1 ? num_graphs : -1) + gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs, glob=true) - GNNGraph(graph, + return GNNGraph(graph, num_nodes, num_edges, num_graphs, graph_indicator, ndata, edata, gdata) @@ -203,7 +201,7 @@ function GNNGraph(g::GNNGraph; ndata = g.ndata, edata = g.edata, gdata = g.gdata ndata = normalize_graphdata(ndata, default_name = :x, n = g.num_nodes) edata = normalize_graphdata(edata, default_name = :e, n = g.num_edges, duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs) + gdata = normalize_graphdata(gdata, default_name = :u, n = g.num_graphs, glob=true) if !isnothing(graph_type) if graph_type == :coo diff --git a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl index 4a3f8b924..7c2b9ede5 100644 --- a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl +++ b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl @@ -144,8 +144,7 @@ function GNNHeteroGraph(data::EDict; ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, - n = num_graphs > 1 ? num_graphs : -1) + gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs, glob = true) end return GNNHeteroGraph(graph, diff --git a/GNNGraphs/src/utils.jl b/GNNGraphs/src/utils.jl index a6e96a3ab..9c2b94057 100644 --- a/GNNGraphs/src/utils.jl +++ b/GNNGraphs/src/utils.jl @@ -129,19 +129,24 @@ function normalize_graphdata(data; default_name::Symbol, kws...) normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...) end -function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed = false) +function normalize_graphdata(data::NamedTuple; default_name::Symbol, n::Int, + duplicate_if_needed::Bool = false, glob::Bool = false) # This had to workaround two Zygote bugs with NamedTuples # https://github.com/FluxML/Zygote.jl/issues/1071 - # https://github.com/FluxML/Zygote.jl/issues/1072 + # https://github.com/FluxML/Zygote.jl/issues/1072 # TODO fixed. Can we simplify something? + if n > 1 @assert all(x -> x isa AbstractArray, data) "Non-array features provided." end - if n <= 1 - # If last array dimension is not 1, add a new dimension. - # This is mostly useful to reshape global feature vectors - # of size D to Dx1 matrices. + if n <= 1 && glob == true + @assert n == 1 + n = -1 # relax the case of a single graph, allowing to store arbitrary types + # # # If last array dimension is not 1, add a new dimension. + # # # This is mostly useful to reshape global feature vectors + # # # of size D to Dx1 matrices. + # TODO remove this and handle better the batching of global features unsqz_last(v::AbstractArray) = size(v)[end] != 1 ? reshape(v, size(v)..., 1) : v unsqz_last(v) = v @@ -161,7 +166,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee for x in data if x isa AbstractArray - @assert size(x)[end]==n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])." + @assert size(x)[end] == n "Wrong size in last dimension for feature array, expected $n but got $(size(x)[end])." end end end diff --git a/GNNGraphs/test/gnngraph.jl b/GNNGraphs/test/gnngraph.jl index 252d2f5c0..2b18fe7b7 100644 --- a/GNNGraphs/test/gnngraph.jl +++ b/GNNGraphs/test/gnngraph.jl @@ -228,6 +228,17 @@ end @test g.num_nodes == 1 @test g.num_edges == 0 @test g.ndata.a == [1] + + g = GNNGraph((Int[], Int[]); ndata=(;a=[1]), edata=(;b=Int[]), num_nodes=1) + @test g.num_nodes == 1 + @test g.num_edges == 0 + @test g.ndata.a == [1] + @test g.edata.b == Int[] + + g = GNNGraph(; edata=(;b=Int[])) + @test g.num_nodes == 0 + @test g.num_edges == 0 + @test g.edata.b == Int[] end