Merge pull request #2 from dagardner-nv/david-fsi_dgl-patch
Update tests
tzemicheal authored Aug 4, 2023
2 parents 7d6a1b4 + d3796f8 commit 9d61da2
Showing 5 changed files with 69 additions and 176 deletions.
2 changes: 0 additions & 2 deletions examples/gnn_fraud_detection_pipeline/requirements.yml
@@ -20,6 +20,4 @@ channels:
- dglteam/label/cu118
dependencies:
- cuml=23.06
- dask>=2023.1.1
- dgl=1.0.2
- distributed>=2023.1.1
70 changes: 31 additions & 39 deletions tests/examples/gnn_fraud_detection_pipeline/conftest.py
@@ -26,32 +26,24 @@
"installing these additional dependencies")


@pytest.fixture(autouse=True, scope='session')
def stellargraph(fail_missing: bool):
@pytest.fixture(name="dgl", autouse=True, scope='session')
def dgl_fixture(fail_missing: bool):
"""
All of the tests in this subdir require stellargraph
All of the tests in this subdir require dgl
"""
yield import_or_skip("stellargraph", reason=SKIP_REASON, fail_missing=fail_missing)
yield import_or_skip("dgl", reason=SKIP_REASON, fail_missing=fail_missing)


@pytest.fixture(autouse=True, scope='session')
def cuml(fail_missing: bool):
@pytest.fixture(name="cuml", autouse=True, scope='session')
def cuml_fixture(fail_missing: bool):
"""
All of the tests in this subdir require cuml
"""
yield import_or_skip("cuml", reason=SKIP_REASON, fail_missing=fail_missing)


@pytest.fixture(autouse=True, scope='session')
def tensorflow(fail_missing: bool):
"""
All of the tests in this subdir require tensorflow
"""
yield import_or_skip("tensorflow", reason=SKIP_REASON, fail_missing=fail_missing)


@pytest.fixture
def config(config):
@pytest.fixture(name="config")
def config_fixture(config):
"""
The GNN fraud detection pipeline utilizes the "other" pipeline mode.
"""
@@ -60,38 +52,36 @@ def config(config):
yield config


@pytest.fixture
def example_dir():
@pytest.fixture(name="example_dir")
def example_dir_fixture():
yield os.path.join(TEST_DIRS.examples_dir, 'gnn_fraud_detection_pipeline')


@pytest.fixture
def training_file(example_dir: str):
@pytest.fixture(name="training_file")
def training_file_fixture(example_dir: str):
yield os.path.join(example_dir, 'training.csv')


@pytest.fixture
def hinsage_model(example_dir: str):
yield os.path.join(example_dir, 'model/hinsage-model.pt')
@pytest.fixture(name="model_dir")
def model_dir_fixture(example_dir: str):
yield os.path.join(example_dir, 'model')


@pytest.fixture
def xgb_model(example_dir: str):
yield os.path.join(example_dir, 'model/xgb-model.pt')
@pytest.fixture(name="xgb_model")
def xgb_model_fixture(model_dir: str):
yield os.path.join(model_dir, 'xgb.pt')


# Some of the code inside gnn_fraud_detection_pipeline performs some relative imports in the form of:
# from .mod import Class
# For this reason we need to ensure that the examples dir is in the sys.path first
@pytest.fixture
def gnn_fraud_detection_pipeline(request: pytest.FixtureRequest, restore_sys_path, reset_plugins):
sys.path.append(TEST_DIRS.examples_dir)
import gnn_fraud_detection_pipeline
yield gnn_fraud_detection_pipeline
@pytest.fixture(name="ex_in_sys_path", autouse=True)
def ex_in_sys_path_fixture(example_dir: str, restore_sys_path, reset_plugins): # pylint: disable=unused-argument
sys.path.append(example_dir)
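

The `restore_sys_path` and `reset_plugins` fixtures requested above come from the Morpheus test suite and presumably undo the `sys.path` (and plugin) changes after each test. As a minimal, self-contained sketch of what a sys.path-restoring fixture typically does (not the Morpheus implementation):

import sys

import pytest


@pytest.fixture(name="restore_sys_path")
def restore_sys_path_fixture():
    # Snapshot sys.path so entries appended by a test, or by fixtures such as
    # ex_in_sys_path above, do not leak into other tests.
    saved = list(sys.path)
    yield sys.path
    sys.path[:] = saved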


@pytest.fixture
def test_data():
@pytest.fixture(name="test_data")
def test_data_fixture():
"""
Construct test data, a small DF of 10 rows which we will build a graph from
The nodes in our graph will be the unique values from each of our three columns, and the index is also
@@ -132,9 +122,11 @@ def test_data():

assert len(expected_edges) == 20 # ensuring test data & assumptions are correct

yield dict(index=index,
client_data=client_data,
merchant_data=merchant_data,
df=df,
expected_nodes=expected_nodes,
expected_edges=expected_edges)
yield {
"index": index,
"client_data": client_data,
"merchant_data": merchant_data,
"df": df,
"expected_nodes": expected_nodes,
"expected_edges": expected_edges
}
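
The truncated docstring above describes the fixture's intent: a small 10-row frame whose graph nodes are the unique values of the client, merchant and index columns, and whose expected edges connect each row's index to its client and merchant nodes. A self-contained sketch with hypothetical values (the real data sits in the collapsed part of the hunk) shows the idea:

import pandas as pd

# Hypothetical 10-row frame standing in for the fixture's real data.
index = list(range(10))
client_data = [f"client_{i % 4}" for i in index]
merchant_data = [f"merchant_{i % 3}" for i in index]
df = pd.DataFrame({"client_node": client_data, "merchant_node": merchant_data}, index=index)

# Nodes are the unique values from the three columns (the index included).
expected_nodes = set(index) | set(client_data) | set(merchant_data)

# One edge from each row's index to its client node and one to its merchant node.
expected_edges = set(zip(df.client_node, df.index)) | set(zip(df.merchant_node, df.index))
assert len(expected_edges) == 20  # matches the sanity check in the fixture above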
@@ -25,25 +25,15 @@
@pytest.mark.use_python
class TestClassificationStage:

def test_constructor(
self,
config: Config,
xgb_model: str,
gnn_fraud_detection_pipeline: types.ModuleType, # pylint: disable=unused-argument
cuml: types.ModuleType):
from gnn_fraud_detection_pipeline.stages.classification_stage import ClassificationStage
def test_constructor(self, config: Config, xgb_model: str, cuml: types.ModuleType):
from stages.classification_stage import ClassificationStage

stage = ClassificationStage(config, xgb_model)
assert isinstance(stage._xgb_model, cuml.ForestInference)
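
The assertion above reflects that ClassificationStage loads the XGBoost model through cuML's Forest Inference Library (FIL). A hedged sketch of doing the same outside the stage, with a hypothetical model path and feature frame (exact keyword names can vary between cuML releases):

import cudf
from cuml import ForestInference

# Hypothetical path and features -- illustrative only.
fil_model = ForestInference.load(filename="model/xgb.pt", output_class=True, model_type="xgboost")

features = cudf.DataFrame({f"ind_emb_{i}": [0.1, 0.9] for i in range(4)})
labels = fil_model.predict(features)                # predicted class per row
probabilities = fil_model.predict_proba(features)   # per-class probabilities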

def test_process_message(
self,
config: Config,
xgb_model: str,
gnn_fraud_detection_pipeline: types.ModuleType, # pylint: disable=unused-argument
dataset_cudf: DatasetManager):
from gnn_fraud_detection_pipeline.stages.classification_stage import ClassificationStage
from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEMultiMessage
def test_process_message(self, config: Config, xgb_model: str, dataset_cudf: DatasetManager):
from stages.classification_stage import ClassificationStage
from stages.graph_sage_stage import GraphSAGEMultiMessage

df = dataset_cudf['examples/gnn_fraud_detection_pipeline/inductive_emb.csv']
df.rename(lambda x: f"ind_emb_{x}", axis=1, inplace=True)
@@ -13,93 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import types
import typing
from io import StringIO

import pandas as pd
import pytest

import cudf

from morpheus.config import Config
from morpheus.messages import MessageMeta
from morpheus.messages import MultiMessage
from utils import TEST_DIRS


@pytest.mark.use_python
@pytest.mark.import_mod(
[os.path.join(TEST_DIRS.examples_dir, 'gnn_fraud_detection_pipeline/stages/graph_construction_stage.py')])
class TestGraphConstructionStage:

def test_constructor(self, config: Config, training_file: str, import_mod: typing.List[types.ModuleType]):
graph_construction_stage = import_mod[0]
stage = graph_construction_stage.FraudGraphConstructionStage(config, training_file)
def test_constructor(self, config: Config, training_file: str):
from stages.graph_construction_stage import FraudGraphConstructionStage
stage = FraudGraphConstructionStage(config, training_file)
assert isinstance(stage._training_data, cudf.DataFrame)

# The training datafile contains many more columns than this, but these are the four columns
# that are depended upon in the code
assert {'client_node', 'index', 'fraud_label', 'merchant_node'}.issubset(stage._column_names)

def _check_graph(
self,
stellargraph: types.ModuleType,
sg: "stellargraph.StellarGraph", # noqa: F821
expected_nodes,
expected_edges):
assert isinstance(sg, stellargraph.StellarGraph)
sg.check_graph_for_ml(features=True, expensive_check=True) # this will raise if it doesn't pass
assert not sg.is_directed()

nodes = sg.nodes()
assert set(nodes) == expected_nodes

edges = sg.edges()
assert set(edges) == expected_edges

def test_graph_construction(self,
import_mod: typing.List[types.ModuleType],
stellargraph: types.ModuleType,
test_data: dict):
graph_construction_stage = import_mod[0]
df = test_data['df']

client_features = pd.DataFrame({0: 1}, index=list(set(test_data['client_data'])))
merchant_features = pd.DataFrame({0: 1}, index=test_data['merchant_data'])

# Call _graph_construction
sg = graph_construction_stage.FraudGraphConstructionStage._graph_construction(
nodes={
'client': df.client_node, 'merchant': df.merchant_node, 'transaction': df.index
},
edges=[
zip(df.client_node, df.index),
zip(df.merchant_node, df.index),
],
node_features={
"transaction": df[['client_node', 'merchant_node']],
"client": client_features,
"merchant": merchant_features
})

self._check_graph(stellargraph, sg, test_data['expected_nodes'], test_data['expected_edges'])

def test_build_graph_features(self,
import_mod: typing.List[types.ModuleType],
stellargraph: types.ModuleType,
test_data: dict):
graph_construction_stage = import_mod[0]
sg = graph_construction_stage.FraudGraphConstructionStage._build_graph_features(test_data['df'])
self._check_graph(stellargraph, sg, test_data['expected_nodes'], test_data['expected_edges'])

def test_process_message(self,
config: Config,
import_mod: typing.List[types.ModuleType],
stellargraph: types.ModuleType,
test_data: dict):
graph_construction_stage = import_mod[0]
def test_process_message(self, dgl: types.ModuleType, config: Config, test_data: dict):
from stages import graph_construction_stage
df = test_data['df']

# The stage wants a csv file from the first 5 rows
@@ -108,12 +47,20 @@ def test_process_message(self,

# Since we used the first 5 rows as the training data, send the second 5 as inference data
meta = MessageMeta(cudf.DataFrame(df))
mm = MultiMessage(meta=meta, mess_offset=5, mess_count=5)
fgmm = stage._process_message(mm)
multi_msg = MultiMessage(meta=meta, mess_offset=5, mess_count=5)
fgmm = stage._process_message(multi_msg)

assert isinstance(fgmm, graph_construction_stage.FraudGraphMultiMessage)
assert fgmm.meta is meta
assert fgmm.mess_offset == 5
assert fgmm.mess_count == 5

self._check_graph(stellargraph, fgmm.graph, test_data['expected_nodes'], test_data['expected_edges'])
assert isinstance(fgmm.graph, dgl.DGLGraph)
fgmm.graph.check_graph_for_ml(features=True, expensive_check=True) # this will raise if it doesn't pass
assert not fgmm.graph.is_directed()

nodes = fgmm.graph.nodes()
assert set(nodes) == test_data['expected_nodes']

edges = fgmm.graph.edges()
assert set(edges) == test_data['expected_edges']
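
These new assertions treat the constructed graph as a dgl.DGLGraph rather than a StellarGraph. For readers new to DGL, a minimal heterogeneous graph with the three node types this example uses (client, merchant, transaction) can be built as below; the IDs and relation names are hypothetical and only illustrate the API, not the exact schema FraudGraphConstructionStage produces:

import dgl
import torch

# Two transactions (0, 1) made by clients (0, 1) at merchants (0, 0).
graph = dgl.heterograph({
    ("client", "buy", "transaction"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
    ("merchant", "sell", "transaction"): (torch.tensor([0, 0]), torch.tensor([0, 1])),
})

assert isinstance(graph, dgl.DGLGraph)
assert set(graph.ntypes) == {"client", "merchant", "transaction"}
assert graph.num_nodes("transaction") == 2

src, dst = graph.edges(etype="buy")  # source client IDs and destination transaction IDs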
@@ -13,84 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import types

import pytest

import cudf

from morpheus.config import Config
from morpheus.messages import MessageMeta
from morpheus.messages import MultiMessage
from utils.dataset_manager import DatasetManager


@pytest.mark.use_python
class TestGraphSageStage:

def test_constructor(self,
config: Config,
hinsage_model: str,
gnn_fraud_detection_pipeline: types.ModuleType,
tensorflow):
from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEStage
stage = GraphSAGEStage(config,
model_hinsage_file=hinsage_model,
batch_size=10,
sample_size=[4, 64],
record_id="test_id",
target_node="test_node")
def test_constructor(self, config: Config, model_dir: str):
from stages.graph_sage_stage import GraphSAGEStage
from stages.model import HinSAGE
stage = GraphSAGEStage(config, model_dir=model_dir, batch_size=10, record_id="test_id", target_node="test_node")

assert isinstance(stage._keras_model, tensorflow.keras.models.Model)
assert isinstance(stage._dgl_model, HinSAGE)
assert stage._batch_size == 10
assert stage._sample_size == [4, 64]
assert stage._record_id == "test_id"
assert stage._target_node == "test_node"

def test_inductive_step_hinsage(self,
config: Config,
hinsage_model: str,
gnn_fraud_detection_pipeline: types.ModuleType,
test_data: dict,
dataset_pandas: DatasetManager):
from gnn_fraud_detection_pipeline.stages.graph_construction_stage import FraudGraphConstructionStage
from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEStage

# The column names in the saved test data will be strings, in the results they will be ints
expected_df = dataset_pandas['examples/gnn_fraud_detection_pipeline/inductive_emb.csv']
expected_df.rename(lambda x: int(x), axis=1, inplace=True)

df = test_data['df']

graph = FraudGraphConstructionStage._build_graph_features(df)

stage = GraphSAGEStage(config, model_hinsage_file=hinsage_model)
results = stage._inductive_step_hinsage(graph, stage._keras_model, test_data['index'])

assert isinstance(results, cudf.DataFrame)
assert results.index.to_arrow().to_pylist() == test_data['index']
dataset_pandas.assert_compare_df(results, expected_df)

def test_process_message(self,
config: Config,
hinsage_model: str,
gnn_fraud_detection_pipeline: types.ModuleType,
training_file: str,
model_dir: str,
test_data: dict,
dataset_pandas: DatasetManager):
from gnn_fraud_detection_pipeline.stages.graph_construction_stage import FraudGraphConstructionStage
from gnn_fraud_detection_pipeline.stages.graph_construction_stage import FraudGraphMultiMessage
from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEMultiMessage
from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEStage
from stages.graph_construction_stage import FraudGraphConstructionStage
from stages.graph_sage_stage import GraphSAGEMultiMessage
from stages.graph_sage_stage import GraphSAGEStage

expected_df = dataset_pandas['examples/gnn_fraud_detection_pipeline/inductive_emb.csv']
expected_df.rename(lambda x: "ind_emb_{}".format(x), axis=1, inplace=True)
expected_df.rename(lambda x: f"ind_emb_{x}", axis=1, inplace=True)

df = test_data['df']
meta = MessageMeta(cudf.DataFrame(df))
graph = FraudGraphConstructionStage._build_graph_features(df)
msg = FraudGraphMultiMessage(meta=meta, graph=graph)
multi_msg = MultiMessage(meta=meta)
construction_stage = FraudGraphConstructionStage(config, training_file)
fgmm_msg = construction_stage._process_message(multi_msg)

stage = GraphSAGEStage(config, model_hinsage_file=hinsage_model)
results = stage._process_message(msg)
stage = GraphSAGEStage(config, model_dir=model_dir)
results = stage._process_message(fgmm_msg)

assert isinstance(results, GraphSAGEMultiMessage)
assert results.meta is meta
