Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ function Base.show(io::IO, l::AGNNConv)
print(io, ")")
end

function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
function (l::AGNNConv)(g, x, ps, st)
β = l.trainable ? ps.β : l.init_beta
m = (; β, l.add_self_loops)
return GNNlib.agnn_conv(m, g, x), st
Expand Down
21 changes: 16 additions & 5 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,23 +334,34 @@ end

####################### AGNNConv ######################################

function agnn_conv(l, g::GNNGraph, x::AbstractMatrix)
function agnn_conv(l, g::AbstractGNNGraph, x)
check_num_nodes(g, x)
if l.add_self_loops
g = add_self_loops(g)
end

xn = x ./ sqrt.(sum(x .^ 2, dims = 1))
cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn)
xj, xi = expand_srcdst(g, x)

xi_n = xi ./ sqrt.(sum(xi .^ 2, dims = 1))
xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1))
Comment thread
i-Amogh marked this conversation as resolved.
Outdated
cos_dist = apply_edges(xi_dot_xj, g, xi = xi_n, xj = xj_n)
α = softmax_edge_neighbors(g, l.β .* cos_dist)

x = propagate(g, +; xj = x, e = α) do xi, xj, α
α .* xj
x = propagate(g, +; xj, e = α) do xi, xj, α
α .* xj
end

return x
end

"""
_has_same_node_types(g::GNNHeteroGraph)

Return true if all edge types in the heterogeneous graph have the same source and
target node types (i.e., no bipartite relations).
"""
_has_same_node_types(g::GNNHeteroGraph) = all(et -> et[1] == et[3], g.etypes)
Comment thread
i-Amogh marked this conversation as resolved.
Outdated

####################### MegNetConv ######################################

function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
Expand Down
9 changes: 8 additions & 1 deletion GraphNeuralNetworks/test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,14 @@ end
g.graph isa AbstractSparseMatrix && continue
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false)
end
end
l_bip = AGNNConv(add_self_loops=false)
s = [1, 1, 2, 3]
t = [1, 2, 1, 2]
g = GNNGraph((s, t)) |> gpu
x = (randn(Float32, D_IN, 3) |> gpu, randn(Float32, D_IN, 2) |> gpu)
y = l_bip(g, x)
@test size(y) == (D_IN, 2)
end

@testitem "MEGNetConv" setup=[TolSnippet, TestModule] begin
Expand Down
10 changes: 9 additions & 1 deletion GraphNeuralNetworks/test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,15 @@
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
(:B, :to, :A) => GCNConv(4 => 2, tanh));
y = layers(g, x);
y = layers(g, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end

@testset "AGNNConv" begin
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv((:A, :to, :B) => AGNNConv(),
(:B, :to, :A) => AGNNConv())
y = layers(hg, x)
@test size(y.A) == (4, 2) && size(y.B) == (4, 3)
end
end
Loading