diff --git a/examples/gnn_fraud_detection_pipeline/requirements.yml b/examples/gnn_fraud_detection_pipeline/requirements.yml index cc02b5c502..01f641c047 100644 --- a/examples/gnn_fraud_detection_pipeline/requirements.yml +++ b/examples/gnn_fraud_detection_pipeline/requirements.yml @@ -20,6 +20,4 @@ channels: - dglteam/label/cu118 dependencies: - cuml=23.06 - - dask>=2023.1.1 - dgl=1.0.2 - - distributed>=2023.1.1 diff --git a/tests/examples/gnn_fraud_detection_pipeline/conftest.py b/tests/examples/gnn_fraud_detection_pipeline/conftest.py index ed8e690878..043b3a588c 100644 --- a/tests/examples/gnn_fraud_detection_pipeline/conftest.py +++ b/tests/examples/gnn_fraud_detection_pipeline/conftest.py @@ -26,32 +26,24 @@ "installing these additional dependencies") -@pytest.fixture(autouse=True, scope='session') -def stellargraph(fail_missing: bool): +@pytest.fixture(name="dgl", autouse=True, scope='session') +def dgl_fixture(fail_missing: bool): """ - All of the tests in this subdir require stellargraph + All of the tests in this subdir require dgl """ - yield import_or_skip("stellargraph", reason=SKIP_REASON, fail_missing=fail_missing) + yield import_or_skip("dgl", reason=SKIP_REASON, fail_missing=fail_missing) -@pytest.fixture(autouse=True, scope='session') -def cuml(fail_missing: bool): +@pytest.fixture(name="cuml", autouse=True, scope='session') +def cuml_fixture(fail_missing: bool): """ All of the tests in this subdir require cuml """ yield import_or_skip("cuml", reason=SKIP_REASON, fail_missing=fail_missing) -@pytest.fixture(autouse=True, scope='session') -def tensorflow(fail_missing: bool): - """ - All of the tests in this subdir require tensorflow - """ - yield import_or_skip("tensorflow", reason=SKIP_REASON, fail_missing=fail_missing) - - -@pytest.fixture -def config(config): +@pytest.fixture(name="config") +def config_fixture(config): """ The GNN fraud detection pipeline utilizes the "other" pipeline mode. """ @@ -60,38 +52,36 @@ def config(config): yield config -@pytest.fixture -def example_dir(): +@pytest.fixture(name="example_dir") +def example_dir_fixture(): yield os.path.join(TEST_DIRS.examples_dir, 'gnn_fraud_detection_pipeline') -@pytest.fixture -def training_file(example_dir: str): +@pytest.fixture(name="training_file") +def training_file_fixture(example_dir: str): yield os.path.join(example_dir, 'training.csv') -@pytest.fixture -def hinsage_model(example_dir: str): - yield os.path.join(example_dir, 'model/hinsage-model.pt') +@pytest.fixture(name="model_dir") +def model_dir_fixture(example_dir: str): + yield os.path.join(example_dir, 'model') -@pytest.fixture -def xgb_model(example_dir: str): - yield os.path.join(example_dir, 'model/xgb-model.pt') +@pytest.fixture(name="xgb_model") +def xgb_model_fixture(model_dir: str): + yield os.path.join(model_dir, 'xgb.pt') # Some of the code inside gnn_fraud_detection_pipeline performs some relative imports in the form of: # from .mod import Class # For this reason we need to ensure that the examples dir is in the sys.path first -@pytest.fixture -def gnn_fraud_detection_pipeline(request: pytest.FixtureRequest, restore_sys_path, reset_plugins): - sys.path.append(TEST_DIRS.examples_dir) - import gnn_fraud_detection_pipeline - yield gnn_fraud_detection_pipeline +@pytest.fixture(name="ex_in_sys_path", autouse=True) +def ex_in_sys_path_fixture(example_dir: str, restore_sys_path, reset_plugins): # pylint: disable=unused-argument + sys.path.append(example_dir) -@pytest.fixture -def test_data(): +@pytest.fixture(name="test_data") +def test_data_fixture(): """ Construct test data, a small DF of 10 rows which we will build a graph from The nodes in our graph will be the unique values from each of our three columns, and the index is also @@ -132,9 +122,11 @@ def test_data(): assert len(expected_edges) == 20 # ensuring test data & assumptions are correct - yield dict(index=index, - client_data=client_data, - merchant_data=merchant_data, - df=df, - expected_nodes=expected_nodes, - expected_edges=expected_edges) + yield { + "index": index, + "client_data": client_data, + "merchant_data": merchant_data, + "df": df, + "expected_nodes": expected_nodes, + "expected_edges": expected_edges + } diff --git a/tests/examples/gnn_fraud_detection_pipeline/test_classification_stage.py b/tests/examples/gnn_fraud_detection_pipeline/test_classification_stage.py index 32eb195df2..09122864a8 100644 --- a/tests/examples/gnn_fraud_detection_pipeline/test_classification_stage.py +++ b/tests/examples/gnn_fraud_detection_pipeline/test_classification_stage.py @@ -25,25 +25,15 @@ @pytest.mark.use_python class TestClassificationStage: - def test_constructor( - self, - config: Config, - xgb_model: str, - gnn_fraud_detection_pipeline: types.ModuleType, # pylint: disable=unused-argument - cuml: types.ModuleType): - from gnn_fraud_detection_pipeline.stages.classification_stage import ClassificationStage + def test_constructor(self, config: Config, xgb_model: str, cuml: types.ModuleType): + from stages.classification_stage import ClassificationStage stage = ClassificationStage(config, xgb_model) assert isinstance(stage._xgb_model, cuml.ForestInference) - def test_process_message( - self, - config: Config, - xgb_model: str, - gnn_fraud_detection_pipeline: types.ModuleType, # pylint: disable=unused-argument - dataset_cudf: DatasetManager): - from gnn_fraud_detection_pipeline.stages.classification_stage import ClassificationStage - from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEMultiMessage + def test_process_message(self, config: Config, xgb_model: str, dataset_cudf: DatasetManager): + from stages.classification_stage import ClassificationStage + from stages.graph_sage_stage import GraphSAGEMultiMessage df = dataset_cudf['examples/gnn_fraud_detection_pipeline/inductive_emb.csv'] df.rename(lambda x: f"ind_emb_{x}", axis=1, inplace=True) diff --git a/tests/examples/gnn_fraud_detection_pipeline/test_graph_construction_stage.py b/tests/examples/gnn_fraud_detection_pipeline/test_graph_construction_stage.py index d39c00d994..73ac9c4555 100644 --- a/tests/examples/gnn_fraud_detection_pipeline/test_graph_construction_stage.py +++ b/tests/examples/gnn_fraud_detection_pipeline/test_graph_construction_stage.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import types -import typing from io import StringIO -import pandas as pd import pytest import cudf @@ -26,80 +23,22 @@ from morpheus.config import Config from morpheus.messages import MessageMeta from morpheus.messages import MultiMessage -from utils import TEST_DIRS @pytest.mark.use_python -@pytest.mark.import_mod( - [os.path.join(TEST_DIRS.examples_dir, 'gnn_fraud_detection_pipeline/stages/graph_construction_stage.py')]) class TestGraphConstructionStage: - def test_constructor(self, config: Config, training_file: str, import_mod: typing.List[types.ModuleType]): - graph_construction_stage = import_mod[0] - stage = graph_construction_stage.FraudGraphConstructionStage(config, training_file) + def test_constructor(self, config: Config, training_file: str): + from stages.graph_construction_stage import FraudGraphConstructionStage + stage = FraudGraphConstructionStage(config, training_file) assert isinstance(stage._training_data, cudf.DataFrame) # The training datafile contains many more columns than this, but these are the four columns # that are depended upon in the code assert {'client_node', 'index', 'fraud_label', 'merchant_node'}.issubset(stage._column_names) - def _check_graph( - self, - stellargraph: types.ModuleType, - sg: "stellargraph.StellarGraph", # noqa: F821 - expected_nodes, - expected_edges): - assert isinstance(sg, stellargraph.StellarGraph) - sg.check_graph_for_ml(features=True, expensive_check=True) # this will raise if it doesn't pass - assert not sg.is_directed() - - nodes = sg.nodes() - assert set(nodes) == expected_nodes - - edges = sg.edges() - assert set(edges) == expected_edges - - def test_graph_construction(self, - import_mod: typing.List[types.ModuleType], - stellargraph: types.ModuleType, - test_data: dict): - graph_construction_stage = import_mod[0] - df = test_data['df'] - - client_features = pd.DataFrame({0: 1}, index=list(set(test_data['client_data']))) - merchant_features = pd.DataFrame({0: 1}, index=test_data['merchant_data']) - - # Call _graph_construction - sg = graph_construction_stage.FraudGraphConstructionStage._graph_construction( - nodes={ - 'client': df.client_node, 'merchant': df.merchant_node, 'transaction': df.index - }, - edges=[ - zip(df.client_node, df.index), - zip(df.merchant_node, df.index), - ], - node_features={ - "transaction": df[['client_node', 'merchant_node']], - "client": client_features, - "merchant": merchant_features - }) - - self._check_graph(stellargraph, sg, test_data['expected_nodes'], test_data['expected_edges']) - - def test_build_graph_features(self, - import_mod: typing.List[types.ModuleType], - stellargraph: types.ModuleType, - test_data: dict): - graph_construction_stage = import_mod[0] - sg = graph_construction_stage.FraudGraphConstructionStage._build_graph_features(test_data['df']) - self._check_graph(stellargraph, sg, test_data['expected_nodes'], test_data['expected_edges']) - - def test_process_message(self, - config: Config, - import_mod: typing.List[types.ModuleType], - stellargraph: types.ModuleType, - test_data: dict): - graph_construction_stage = import_mod[0] + def test_process_message(self, dgl: types.ModuleType, config: Config, test_data: dict): + from stages import graph_construction_stage df = test_data['df'] # The stage wants a csv file from the first 5 rows @@ -108,12 +47,20 @@ def test_process_message(self, # Since we used the first 5 rows as the training data, send the second 5 as inference data meta = MessageMeta(cudf.DataFrame(df)) - mm = MultiMessage(meta=meta, mess_offset=5, mess_count=5) - fgmm = stage._process_message(mm) + multi_msg = MultiMessage(meta=meta, mess_offset=5, mess_count=5) + fgmm = stage._process_message(multi_msg) assert isinstance(fgmm, graph_construction_stage.FraudGraphMultiMessage) assert fgmm.meta is meta assert fgmm.mess_offset == 5 assert fgmm.mess_count == 5 - self._check_graph(stellargraph, fgmm.graph, test_data['expected_nodes'], test_data['expected_edges']) + assert isinstance(fgmm.graph, dgl.DGLGraph) + fgmm.graph.check_graph_for_ml(features=True, expensive_check=True) # this will raise if it doesn't pass + assert not fgmm.graph.is_directed() + + nodes = fgmm.graph.nodes() + assert set(nodes) == test_data['expected_nodes'] + + edges = fgmm.graph.edges() + assert set(edges) == test_data['expected_edges'] diff --git a/tests/examples/gnn_fraud_detection_pipeline/test_graph_sage_stage.py b/tests/examples/gnn_fraud_detection_pipeline/test_graph_sage_stage.py index b8a449de48..9a686157f1 100644 --- a/tests/examples/gnn_fraud_detection_pipeline/test_graph_sage_stage.py +++ b/tests/examples/gnn_fraud_detection_pipeline/test_graph_sage_stage.py @@ -13,84 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -import types - import pytest import cudf from morpheus.config import Config from morpheus.messages import MessageMeta +from morpheus.messages import MultiMessage from utils.dataset_manager import DatasetManager @pytest.mark.use_python class TestGraphSageStage: - def test_constructor(self, - config: Config, - hinsage_model: str, - gnn_fraud_detection_pipeline: types.ModuleType, - tensorflow): - from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEStage - stage = GraphSAGEStage(config, - model_hinsage_file=hinsage_model, - batch_size=10, - sample_size=[4, 64], - record_id="test_id", - target_node="test_node") + def test_constructor(self, config: Config, model_dir: str): + from stages.graph_sage_stage import GraphSAGEStage + from stages.model import HinSAGE + stage = GraphSAGEStage(config, model_dir=model_dir, batch_size=10, record_id="test_id", target_node="test_node") - assert isinstance(stage._keras_model, tensorflow.keras.models.Model) + assert isinstance(stage._dgl_model, HinSAGE) assert stage._batch_size == 10 - assert stage._sample_size == [4, 64] assert stage._record_id == "test_id" assert stage._target_node == "test_node" - def test_inductive_step_hinsage(self, - config: Config, - hinsage_model: str, - gnn_fraud_detection_pipeline: types.ModuleType, - test_data: dict, - dataset_pandas: DatasetManager): - from gnn_fraud_detection_pipeline.stages.graph_construction_stage import FraudGraphConstructionStage - from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEStage - - # The column names in the saved test data will be strings, in the results they will be ints - expected_df = dataset_pandas['examples/gnn_fraud_detection_pipeline/inductive_emb.csv'] - expected_df.rename(lambda x: int(x), axis=1, inplace=True) - - df = test_data['df'] - - graph = FraudGraphConstructionStage._build_graph_features(df) - - stage = GraphSAGEStage(config, model_hinsage_file=hinsage_model) - results = stage._inductive_step_hinsage(graph, stage._keras_model, test_data['index']) - - assert isinstance(results, cudf.DataFrame) - assert results.index.to_arrow().to_pylist() == test_data['index'] - dataset_pandas.assert_compare_df(results, expected_df) - def test_process_message(self, config: Config, - hinsage_model: str, - gnn_fraud_detection_pipeline: types.ModuleType, + training_file: str, + model_dir: str, test_data: dict, dataset_pandas: DatasetManager): - from gnn_fraud_detection_pipeline.stages.graph_construction_stage import FraudGraphConstructionStage - from gnn_fraud_detection_pipeline.stages.graph_construction_stage import FraudGraphMultiMessage - from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEMultiMessage - from gnn_fraud_detection_pipeline.stages.graph_sage_stage import GraphSAGEStage + from stages.graph_construction_stage import FraudGraphConstructionStage + from stages.graph_sage_stage import GraphSAGEMultiMessage + from stages.graph_sage_stage import GraphSAGEStage expected_df = dataset_pandas['examples/gnn_fraud_detection_pipeline/inductive_emb.csv'] - expected_df.rename(lambda x: "ind_emb_{}".format(x), axis=1, inplace=True) + expected_df.rename(lambda x: f"ind_emb_{x}", axis=1, inplace=True) df = test_data['df'] meta = MessageMeta(cudf.DataFrame(df)) - graph = FraudGraphConstructionStage._build_graph_features(df) - msg = FraudGraphMultiMessage(meta=meta, graph=graph) + multi_msg = MultiMessage(meta=meta) + construction_stage = FraudGraphConstructionStage(config, training_file) + fgmm_msg = construction_stage._process_message(multi_msg) - stage = GraphSAGEStage(config, model_hinsage_file=hinsage_model) - results = stage._process_message(msg) + stage = GraphSAGEStage(config, model_dir=model_dir) + results = stage._process_message(fgmm_msg) assert isinstance(results, GraphSAGEMultiMessage) assert results.meta is meta