From 4b21a740cd0945200ada2a6fdae5a17bfa600d5d Mon Sep 17 00:00:00 2001
From: Anthony Mahanna <43019056+aMahanna@users.noreply.github.com>
Date: Mon, 22 Apr 2024 17:57:56 -0400
Subject: [PATCH] add edge type assertions (#37)

---
 tests/test_adapter.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tests/test_adapter.py b/tests/test_adapter.py
index 7adc9d1..0828361 100644
--- a/tests/test_adapter.py
+++ b/tests/test_adapter.py
@@ -5,7 +5,7 @@
 from dgl import DGLGraph, DGLHeteroGraph
 from dgl.view import EdgeSpace, NodeSpace
 from pandas import DataFrame
-from torch import Tensor, cat, long, tensor
+from torch import Tensor, cat, int64, long, tensor
 
 from adbdgl_adapter import ADBDGL_Adapter
 from adbdgl_adapter.encoders import CategoricalEncoder, IdentityEncoder
@@ -573,6 +573,8 @@ def test_adb_partial_to_dgl() -> None:
     # Grab the same nodes from the Homogeneous graph
     from_nodes_new, to_nodes_new = dgl_g_new.edges(etype=None)
 
+    assert from_nodes.dtype == from_nodes_new.dtype
+    assert to_nodes.dtype == to_nodes_new.dtype
     assert from_nodes.tolist() == from_nodes_new.tolist()
     assert to_nodes.tolist() == to_nodes_new.tolist()
 
@@ -772,8 +774,11 @@ def assert_adb_to_dgl(
             from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist()
             to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist()
 
-            assert from_nodes == dgl_g.edges(etype=e_key)[0].tolist()
-            assert to_nodes == dgl_g.edges(etype=e_key)[1].tolist()
+            src, dst = dgl_g.edges(etype=e_key)
+            assert src.dtype == int64
+            assert dst.dtype == int64
+            assert from_nodes == src.tolist()
+            assert to_nodes == dst.tolist()
 
             assert_adb_to_dgl_meta(meta, et_df, dgl_g.edges[e_key].data)