Skip to content

Commit

Permalink
Fix unit test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Apr 11, 2024
1 parent 2de0305 commit f553c27
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions grand/backends/_sqlbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def get_node_neighbors(
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._edge_source_key] == str(u)
)
).order_by(self._edge_table.c[self._primary_key])
).fetchall()
else:
res = self._connection.execute(
Expand All @@ -395,7 +395,7 @@ def get_node_neighbors(
(self._edge_table.c[self._edge_source_key] == str(u)),
(self._edge_table.c[self._edge_target_key] == str(u)),
)
)
).order_by(self._edge_table.c[self._primary_key])
).fetchall()

res = [x._asdict() for x in res]
Expand All @@ -404,7 +404,7 @@ def get_node_neighbors(
return {
(
r[self._edge_source_key]
if r[self._edge_source_key] != u
if r[self._edge_source_key] != str(u)
else r[self._edge_target_key]
): r["_metadata"]
for r in res
Expand All @@ -414,7 +414,7 @@ def get_node_neighbors(
[
(
r[self._edge_source_key]
if r[self._edge_source_key] != u
if r[self._edge_source_key] != str(u)
else r[self._edge_target_key]
)
for r in res
Expand All @@ -438,7 +438,7 @@ def get_node_predecessors(
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._edge_target_key] == str(u)
)
).order_by(self._edge_table.c[self._primary_key])
).fetchall()
else:
res = self._connection.execute(
Expand All @@ -447,7 +447,7 @@ def get_node_predecessors(
(self._edge_table.c[self._edge_target_key] == str(u)),
(self._edge_table.c[self._edge_source_key] == str(u)),
)
)
).order_by(self._edge_table.c[self._primary_key])
).fetchall()

res = [x._asdict() for x in res]
Expand All @@ -456,7 +456,7 @@ def get_node_predecessors(
return {
(
r[self._edge_source_key]
if r[self._edge_source_key] != u
if r[self._edge_source_key] != str(u)
else r[self._edge_target_key]
): r["_metadata"]
for r in res
Expand All @@ -466,7 +466,7 @@ def get_node_predecessors(
[
(
r[self._edge_source_key]
if r[self._edge_source_key] != u
if r[self._edge_source_key] != str(u)
else r[self._edge_target_key]
)
for r in res
Expand Down Expand Up @@ -524,10 +524,9 @@ def out_degrees(self, nbunch=None):
if where_clause is not None:
query = query.where(where_clause)

results = [x._asdict() for x in self._connection.execute(query).fetchall()]
results = {
r[self._edge_source_key]: r[1]
for r in results
r[0]: r[1]
for r in self._connection.execute(query)
}

if nbunch and not isinstance(nbunch, (list, tuple)):
Expand Down Expand Up @@ -570,10 +569,9 @@ def in_degrees(self, nbunch=None):
if where_clause is not None:
query = query.where(where_clause)

results = [x._asdict() for x in self._connection.execute(query).fetchall()]
results = {
r[self._edge_target_key]: r[1]
for r in results
r[0]: r[1]
for r in self._connection.execute(query)
}

if nbunch and not isinstance(nbunch, (list, tuple)):
Expand Down

0 comments on commit f553c27

Please sign in to comment.