
Commit 8dffbf1

Generic kumo_loader function that can point to a snowflake/s3/local dataset (rusty1s#156)

* Add a generic `kumo_loader` function that can point to a snowflake/s3/local dataset

* Clean up code, switch test env from snowflake to s3 to save CI costs

* lint

* change test data location to local
JiaxuanYou authored Jan 11, 2022
1 parent cd917a7 commit 8dffbf1
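
In effect, any config with format: snowflake or format: csv is now routed through a single kumo_loader entry point. A minimal sketch of the intended call path (a sketch only, assuming the package layout in this PR; the config values are illustrative and mirror test/train/configs/financial.yaml further down):

from torch_geometric.graphgym.config import cfg
from kumo.train.loader import kumo_loader

# Illustrative values, mirroring the test config in this PR:
cfg.dataset.format = 'csv'             # or 'snowflake'
cfg.dataset.data_dir = 'test/csv_data'
cfg.dataset.metadata_dir = 'test/csv_data'
cfg.dataset.name = 'FINANCIAL'
cfg.dataset.target_table = 'LOAN'
cfg.dataset.target_column = 'STATUS'

data = kumo_loader(cfg)                # returns a kumo Store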
Showing 8 changed files with 135 additions and 93 deletions.
7 changes: 4 additions & 3 deletions benchmark/train/configs/financial.yaml
@@ -6,9 +6,10 @@ snowflake:
  warehouse: WH_XS
  database: kumo
dataset:
  format: snowflake
  location: local
  name: Financial
  format: csv
  data_dir: 's3://kumo-datasets'
  metadata_dir: 'test/csv_data'
  name: FINANCIAL
  target_table: LOAN
  target_column: STATUS
  task: node
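
For reference, a quick sketch (not part of the commit) of where these new keys resolve, following the path rules in the new kumo_loader further down this diff:

import os.path as osp

name = 'FINANCIAL'
# s3 data dirs get '<name lowercased>/csv' appended when the name is absent:
osp.join('s3://kumo-datasets', name.lower(), 'csv')
# -> 's3://kumo-datasets/financial/csv' (on POSIX)
# metadata_dir gets the dataset name, then the default 'metadata.yml':
osp.join('test/csv_data', name, 'metadata.yml')
# -> 'test/csv_data/FINANCIAL/metadata.yml' (relative paths are first
#    resolved against the repository root)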
1 change: 0 additions & 1 deletion benchmark/train/configs/financial_regression.yaml
@@ -7,7 +7,6 @@ snowflake:
  database: kumo
dataset:
  format: snowflake
  location: local
  name: Financial
  target_table: LOAN
  target_column: AMOUNT
1 change: 0 additions & 1 deletion benchmark/train/configs/imdb_classification.yaml
@@ -7,7 +7,6 @@ snowflake:
  database: kumo
dataset:
  format: snowflake
  location: local
  name: IMDB
  target_table: U2BASE
  target_column: RATING
34 changes: 18 additions & 16 deletions kumo/config/config.py
@@ -20,53 +20,55 @@ def set_cfg(cfg):
    :return: configuration used by the experiment.
    '''

    # Set defaults
    # Change defaults from PyG GraphGym version
    cfg.model.type = 'heterognn'
    cfg.gnn.head = 'node'
    cfg.dataset.name = 'Financial'
    cfg.gnn.dim_emb = 16

    # Overwrite GraphGym scheduler
    # (might improve default optimizer after more training experiences
    # on databases)
    cfg.optim.scheduler = 'none'

    # ----------------------------------------------------------------------- #
    # Snowflake options
    # New options in Kumo
    # ----------------------------------------------------------------------- #
    cfg.snowflake = CN()

    cfg.snowflake = CN()
    # Account name
    cfg.snowflake.account = 'xva19026'

    # User name
    cfg.snowflake.user = ''

    # Password
    cfg.snowflake.password = ''

    # Warehouse name
    cfg.snowflake.warehouse = 'WH_XS'

    # Database name
    cfg.snowflake.database = 'kumo'

    # directory for dataset
    cfg.dataset.data_dir = ''
    # directory for dataset metadata
    cfg.dataset.metadata_dir = 'test/csv_data'
    # Default to random split
    cfg.dataset.split_type = 'random'
    # If split_type == 'column', split by the values of the column:
    # lower values of this column will be in the training split;
    # the highest values of this column will be in the test split.
    # Restriction: the split column has to be in the prediction target table.
    cfg.dataset.split_column = None
    # TODO: Duplicate `label_table` and `label_column`
    cfg.dataset.target_table = cfg.dataset.label_table
    cfg.dataset.target_column = cfg.dataset.label_column
    # Tables where shallow embeddings are included for feature augmentation
    cfg.dataset.augment_table = []

    # early stopping configs
    cfg.optim.early_stopping = False
    cfg.optim.min_delta = 0.001
    # if None, set patience = total num epochs / 10
    cfg.optim.patience = None

    # Overwrite GraphGym scheduler
    # (might improve default optimizer after more training experiences
    # on databases)
    cfg.optim.scheduler = 'none'

    # TODO: Duplicate `label_table` and `label_column`
    cfg.dataset.target_table = cfg.dataset.label_table
    cfg.dataset.target_column = cfg.dataset.label_column


set_cfg(cfg)
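
As an aside, the split_type == 'column' behavior described in the comment above can be pictured with a minimal pandas sketch (not from this commit): rows with the lowest values of the split column land in train, the highest in test.

import pandas as pd

def split_by_column(df: pd.DataFrame, column: str, ratios=(0.8, 0.1, 0.1)):
    # Sort ascending by the split column: low values train, high values test.
    order = df.sort_values(column).index
    n_train = int(ratios[0] * len(order))
    n_val = int(ratios[1] * len(order))
    return (order[:n_train], order[n_train:n_train + n_val],
            order[n_train + n_val:])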
49 changes: 0 additions & 49 deletions kumo/custom/loader/financial.py

This file was deleted.

104 changes: 101 additions & 3 deletions kumo/train/loader.py
@@ -1,4 +1,5 @@
import logging
import os
import os.path as osp
import pandas as pd
import torch
@@ -13,6 +14,8 @@
from kumo.store import Store
from torch_geometric.graphgym.loader import index2mask
from torch_geometric.data import InMemoryDataset
from kumo.scan import DatabaseMetadata, DatabaseStats
from kumo.connector import SnowflakeConnector, CSVConnector


def preprocess_dataset(data: Store):
@@ -28,7 +31,99 @@ def preprocess_dataset(data: Store):
    for key in data.metadata()[0]:
        if "x" not in data[key]:
            data[key].x = torch.ones(data[key].num_nodes, 1)
    # todo: necessary data transformation
    return data


def kumo_loader(cfg: CfgNode):
    """
    Load dataset following Kumo-supported APIs
    Args:
        cfg (CfgNode): Global config object
    Returns: Kumo Store object
    """

    # Load dataset metadata
    metadata_dir = cfg.dataset.metadata_dir
    assert metadata_dir is not None, \
        "cfg.dataset.metadata_dir is required in yaml config"
    if not osp.isabs(metadata_dir):
        metadata_dir = osp.join(osp.dirname(osp.realpath(__file__)), "..",
                                "..", metadata_dir)
    if cfg.dataset.name not in metadata_dir:
        metadata_dir = osp.join(metadata_dir, cfg.dataset.name)
    if ".yml" not in metadata_dir and ".yaml" not in metadata_dir:
        metadata_dir = osp.join(metadata_dir, "metadata.yml")  # default
    dbmeta = DatabaseMetadata.load(metadata_dir)
    dbmeta.set_target(cfg.dataset.target_table, cfg.dataset.target_column)
    if cfg.dataset.split_column is not None:
        dbmeta.set_split(cfg.dataset.target_table, cfg.dataset.split_column)

    # Load dataset
    if cfg.dataset.format == "snowflake":
        account = cfg.snowflake.account if cfg.snowflake.account is not None \
            else os.getenv("SNOWFLAKE_ACCOUNT")
        user = cfg.snowflake.user if cfg.snowflake.user is not None \
            else os.getenv("SNOWFLAKE_USER")
        password = cfg.snowflake.password \
            if cfg.snowflake.password is not None \
            else os.getenv("SNOWFLAKE_PASSWORD")
        warehouse = cfg.snowflake.warehouse \
            if cfg.snowflake.warehouse is not None \
            else os.getenv("SNOWFLAKE_WAREHOUSE")
        database = cfg.snowflake.database \
            if cfg.snowflake.database is not None \
            else os.getenv("SNOWFLAKE_DATABASE")
        assert account is not None, \
            "SNOWFLAKE_ACCOUNT required in environment variable or yaml config"
        assert user is not None, \
            "SNOWFLAKE_USER required in environment variable or yaml config"
        assert password is not None, \
            "SNOWFLAKE_PASSWORD required in " \
            "environment variable or yaml config"
        assert warehouse is not None, \
            "SNOWFLAKE_WAREHOUSE required in " \
            "environment variable or yaml config"
        assert database is not None, \
            "SNOWFLAKE_DATABASE required in " \
            "environment variable or yaml config"
        connector = SnowflakeConnector(account=account, user=user,
                                       password=password, warehouse=warehouse,
                                       database=database,
                                       schema=cfg.dataset.name)
    elif cfg.dataset.format == "csv":
        data_dir = cfg.dataset.data_dir
        assert data_dir is not None, "cfg.dataset.data_dir is required"
        if "s3:" in data_dir:
            if cfg.dataset.name not in data_dir:
                data_dir = osp.join(data_dir, cfg.dataset.name.lower(), "csv")
        elif not osp.isabs(data_dir):
            data_dir = osp.join(osp.dirname(osp.realpath(__file__)), "..",
                                "..", data_dir)
            if cfg.dataset.name not in data_dir:
                data_dir = osp.join(data_dir, cfg.dataset.name)
        else:
            raise ValueError("{} not found".format(cfg.dataset.data_dir))

        connector = CSVConnector(data_dir, na_values="?")
    else:
        raise ValueError("Unrecognized database format: {}".format(
            cfg.dataset.format))

    dbstats = DatabaseStats.from_connector(connector, dbmeta)
    dbstats.print_summary()

    data = Store.from_connector(connector, dbmeta, dbstats,
                                cfg.dataset.augment_table)

    # todo: Temporary workaround for financial prediction task
    # todo: in order to compare results with baselines
    if cfg.dataset.target_table == "LOAN" and \
            cfg.dataset.target_column == "STATUS":
        data['LOAN'].y[data["LOAN"].y == 2] = 0
        data["LOAN"].y[data["LOAN"].y == 3] = 1

    return data


@@ -47,13 +142,16 @@ def load_dataset(cfg: CfgNode, **kwargs):
    dataset_dir = osp.join(cfg.dataset.dir, name)
    target_table = cfg.dataset.target_table
    target_column = cfg.dataset.target_column
    # Try to load customized data format
    # First try to load with any customized data loader
    for func in register.loader_dict.values():
        dataset = func(format, name, dataset_dir, target_table, target_column,
                       cfg.dataset.split_column)
        if dataset is not None:
            return dataset
    if format == "PyG":
    # Then try to load with standard Kumo data loaders
    if format == "snowflake" or format == "csv":
        dataset = kumo_loader(cfg)
    elif format == "PyG":
        try:
            dataset = getattr(pyg_dataset,
                              name)(dataset_dir,
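
Taken together with the dispatch above, a hedged sketch of the Snowflake branch (values illustrative; as written, the env-var fallback applies to any credential field that is None rather than set in the yaml):

import os
from torch_geometric.graphgym.config import cfg
from kumo.train.loader import kumo_loader

os.environ['SNOWFLAKE_USER'] = 'alice'       # illustrative
os.environ['SNOWFLAKE_PASSWORD'] = 'secret'  # illustrative
cfg.dataset.format = 'snowflake'
cfg.dataset.name = 'FINANCIAL'
cfg.snowflake.user = None      # force the os.getenv fallback
cfg.snowflake.password = None
data = kumo_loader(cfg)        # Store built via SnowflakeConnector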
21 changes: 12 additions & 9 deletions test/train/configs/financial.yaml
@@ -6,40 +6,43 @@ snowflake:
  warehouse: WH_XS
  database: kumo
dataset:
  format: snowflake
  location: local
  name: Financial
  format: csv
  data_dir: 'test/csv_data'
  metadata_dir: 'test/csv_data'
  name: FINANCIAL
  target_table: LOAN
  target_column: STATUS
  task: node
  task_type: classification
  split: [0.8, 0.1, 0.1]
  split_mode: random
  split_column: DATE # only needed when split_mode = column
  encoder: True
  encoder_name: db
  encoder_bn: True
train:
  mode: db_fast
  sampler: neighbor
  sampler: full_batch
  neighbor_sizes: [10,10,10,10]
  batch_size: 512
  eval_period: 20
  ckpt_period: 100
val:
  sampler: neighbor
  sampler: full_batch
model:
  type: heterognn
gnn:
  layers_pre_mp: 1
  layers_pre_mp: 2
  layers_mp: 3
  layers_post_mp: 1
  dim_inner: 64
  layer_type: SAGEConv
  stage_type: stack
  batchnorm: True
  batchnorm: False
  act: prelu
  dropout: 0.0
  agg: mean
optim:
  optimizer: adam
  base_lr: 0.01
  max_epoch: 200
  base_lr: 0.001
  max_epoch: 50
11 changes: 0 additions & 11 deletions test/train/test_trainer.py
@@ -1,5 +1,4 @@
import os.path as osp
from os import environ

from collections import namedtuple

@@ -21,16 +20,6 @@ def test_loader():
    Args = namedtuple("Args", ["cfg_file", "repeat", "opts"])
    args = Args(osp.join(root, "configs", "financial.yaml"), 1, [])

    if (
        environ.get("SNOWFLAKE_ACCOUNT") is None
        or environ.get("SNOWFLAKE_USER") is None
        or environ.get("SNOWFLAKE_PASSWORD") is None
    ):
        raise Exception(
            "Set Snowflake env (SNOWFLAKE_ACCOUNT, "
            "SNOWFLAKE_USER, SNOWFLAKE_PASSWORD)"
        )

    load_cfg(cfg, args)
    dump_cfg(cfg)
    # Repeat for different random seeds
