From 2062f1a969f5b4ca96ab878b819aa231412a7890 Mon Sep 17 00:00:00 2001 From: Jordan Matelsky Date: Wed, 11 Dec 2024 13:39:42 -0500 Subject: [PATCH] test: Add dataframe-from-existing test --- grand/backends/_dataframe.py | 8 +++++--- grand/backends/test_backends.py | 34 +++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/grand/backends/_dataframe.py b/grand/backends/_dataframe.py index 4beed47..2dd2833 100644 --- a/grand/backends/_dataframe.py +++ b/grand/backends/_dataframe.py @@ -33,10 +33,12 @@ def __init__( use as the node ID """ self._directed = directed - self._edge_df = edge_df or pd.DataFrame( - columns=[edge_df_source_column, edge_df_target_column] + self._edge_df = ( + edge_df + if edge_df is not None + else pd.DataFrame(columns=[edge_df_source_column, edge_df_target_column]) ) - self._node_df = node_df or None + self._node_df = node_df if node_df is not None else None self._edge_df_source_column = edge_df_source_column self._edge_df_target_column = edge_df_target_column self._node_df_id_column = node_df_id_column diff --git a/grand/backends/test_backends.py b/grand/backends/test_backends.py index 4e396e0..2fdc268 100644 --- a/grand/backends/test_backends.py +++ b/grand/backends/test_backends.py @@ -1,5 +1,6 @@ import pytest import os +import pandas as pd import networkx as nx @@ -448,3 +449,36 @@ def test_get_density_performance(backend): for i in range(1000 - 1): G.nx.add_edge(i, i + 1) assert nx.density(G.nx) <= 0.005 + + +class TestDataFrameBackend: + + def test_can_create_empty(self): + b = DataFrameBackend() + assert b.get_edge_count() == 0 + assert b.get_node_count() == 0 + + b.add_edge("A", "B", {}) + assert b.get_edge_count() == 1 + assert b.get_node_count() == 2 + + def test_can_create_from_int_dataframes(self): + # Create an edges DataFrame + edges = pd.DataFrame( + { + "source": [0, 1, 2, 3, 4], + "target": [1, 2, 3, 4, 0], + "weight": [1, 2, 3, 4, 5], + } + ) + + nodes = pd.DataFrame( + { + "name": [0, 1, 2, 3, 4], + "value": [1, 2, 3, 4, 5], + } + ) + + b = DataFrameBackend(edge_df=edges, node_df=nodes) + assert b.get_edge_count() == 5 + assert b.get_node_count() == 5