Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Inputsize invariant dev #308

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chainer_chemistry/dataset/preprocessors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from chainer_chemistry.dataset.preprocessors.common import MolFeatureExtractionError # NOQA
from chainer_chemistry.dataset.preprocessors.common import type_check_num_atoms # NOQA
from chainer_chemistry.dataset.preprocessors.ecfp_preprocessor import ECFPPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.relgat_preprocessor import RelGATPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.ggnn_preprocessor import GGNNPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.mol_preprocessor import MolPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.nfp_preprocessor import NFPPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.relgat_preprocessor import RelGATPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.relgcn_preprocessor import RelGCNPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.rsgcn_preprocessor import RSGCNPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.schnet_preprocessor import SchNetPreprocessor # NOQA
Expand Down
10 changes: 10 additions & 0 deletions chainer_chemistry/dataset/preprocessors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def construct_atomic_number_array(mol, out_size=-1):
'.'.format(out_size, n_atom))


def construct_is_real_node(mol, out_size=-1):
    """Returns a mask array telling which node slots hold real atoms.

    Args:
        mol (Mol): Input molecule. Only ``GetNumAtoms()`` is called on it.
        out_size (int): Length of the returned array. A negative value
            means no padding, so the length equals the number of atoms.
            Otherwise it must be larger than or equal to the number of
            atoms; the trailing slots are the padded (virtual) nodes.

    Returns:
        numpy.ndarray: 1-dim ``float32`` array whose value is 1 for slots
            that correspond to an atom of ``mol`` and 0 for padded slots.

    Raises:
        ValueError: If ``out_size`` is non-negative and smaller than the
            number of atoms in ``mol``.

    """
    num_atoms = mol.GetNumAtoms()
    if out_size < 0:
        # No padding requested: every slot is a real atom.
        return numpy.ones(num_atoms, dtype=numpy.float32)
    if out_size < num_atoms:
        # Without this check the slice assignment below would silently
        # produce an all-ones mask of length `out_size`.
        raise ValueError('`out_size` (={}) must be negative or larger than '
                         'or equal to the number of atoms in the input '
                         'molecule (={}).'.format(out_size, num_atoms))
    is_real_node = numpy.zeros(out_size, dtype=numpy.float32)
    is_real_node[:num_atoms] = 1.
    return is_real_node


# --- Adjacency matrix preprocessing ---
def construct_adj_matrix(mol, out_size=-1, self_connection=True):
"""Returns the adjacent matrix of the given molecule.
Expand Down
17 changes: 13 additions & 4 deletions chainer_chemistry/dataset/preprocessors/ggnn_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from chainer_chemistry.dataset.preprocessors.common \
import construct_atomic_number_array, construct_discrete_edge_matrix
from chainer_chemistry.dataset.preprocessors.common import \
construct_atomic_number_array, construct_discrete_edge_matrix, \
construct_is_real_node
from chainer_chemistry.dataset.preprocessors.common import type_check_num_atoms
from chainer_chemistry.dataset.preprocessors.mol_preprocessor \
import MolPreprocessor
Expand All @@ -20,18 +21,20 @@ class GGNNPreprocessor(MolPreprocessor):
Setting negative value indicates do not pad returned array.
add_Hs (bool): If True, implicit Hs are added.
kekulize (bool): If True, Kekulizes the molecule.
return_is_real_node (bool): If True, also returns `is_real_node`.

"""

def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False,
kekulize=False):
kekulize=False, return_is_real_node=True):
super(GGNNPreprocessor, self).__init__(
add_Hs=add_Hs, kekulize=kekulize)
if max_atoms >= 0 and out_size >= 0 and max_atoms > out_size:
raise ValueError('max_atoms {} must be less or equal to '
'out_size {}'.format(max_atoms, out_size))
self.max_atoms = max_atoms
self.out_size = out_size
self.return_is_real_node = return_is_real_node

def get_input_features(self, mol):
"""get input features
Expand All @@ -45,4 +48,10 @@ def get_input_features(self, mol):
type_check_num_atoms(mol, self.max_atoms)
atom_array = construct_atomic_number_array(mol, out_size=self.out_size)
adj_array = construct_discrete_edge_matrix(mol, out_size=self.out_size)
return atom_array, adj_array
if not self.return_is_real_node:
return atom_array, adj_array
else:
is_real_node = construct_is_real_node(
mol, self.out_size)
return atom_array, adj_array, is_real_node

17 changes: 14 additions & 3 deletions chainer_chemistry/dataset/preprocessors/nfp_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from chainer_chemistry.dataset.preprocessors.common import construct_adj_matrix
from chainer_chemistry.dataset.preprocessors.common import \
construct_adj_matrix, construct_is_real_node
from chainer_chemistry.dataset.preprocessors.common \
import construct_atomic_number_array
from chainer_chemistry.dataset.preprocessors.common import type_check_num_atoms
Expand All @@ -21,18 +22,20 @@ class NFPPreprocessor(MolPreprocessor):
Setting negative value indicates do not pad returned array.
add_Hs (bool): If True, implicit Hs are added.
kekulize (bool): If True, Kekulizes the molecule.
return_is_real_node (bool): If True, also returns `is_real_node`.

"""

def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False,
kekulize=False):
kekulize=False, return_is_real_node=True):
super(NFPPreprocessor, self).__init__(
add_Hs=add_Hs, kekulize=kekulize)
if max_atoms >= 0 and out_size >= 0 and max_atoms > out_size:
raise ValueError('max_atoms {} must be less or equal to '
'out_size {}'.format(max_atoms, out_size))
self.max_atoms = max_atoms
self.out_size = out_size
self.return_is_real_node = return_is_real_node

def get_input_features(self, mol):
"""get input features
Expand All @@ -41,9 +44,17 @@ def get_input_features(self, mol):
mol (Mol):

Returns:
atom_array (numpy.ndarray): (node,)
adj_array (numpy.ndarray): (node, node)
is_real_node (numpy.ndarray): (node,)

"""
type_check_num_atoms(mol, self.max_atoms)
atom_array = construct_atomic_number_array(mol, out_size=self.out_size)
adj_array = construct_adj_matrix(mol, out_size=self.out_size)
return atom_array, adj_array
if not self.return_is_real_node:
return atom_array, adj_array
else:
is_real_node = construct_is_real_node(
mol, out_size=self.out_size)
return atom_array, adj_array, is_real_node
40 changes: 11 additions & 29 deletions chainer_chemistry/dataset/preprocessors/relgat_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
from chainer_chemistry.dataset.preprocessors.common import construct_atomic_number_array # NOQA
from chainer_chemistry.dataset.preprocessors.common import construct_discrete_edge_matrix # NOQA
from chainer_chemistry.dataset.preprocessors.common import MolFeatureExtractionError # NOQA
from chainer_chemistry.dataset.preprocessors.common import type_check_num_atoms
from chainer_chemistry.dataset.preprocessors.mol_preprocessor import MolPreprocessor # NOQA
from chainer_chemistry.dataset.preprocessors.ggnn_preprocessor import GGNNPreprocessor # NOQA


class RelGATPreprocessor(MolPreprocessor):
"""RelGAT Preprocessor
class RelGATPreprocessor(GGNNPreprocessor):
"""RelGAT Preprocessor

Args:
max_atoms (int): Max number of atoms for each molecule, if the
Expand All @@ -18,27 +14,13 @@ class RelGATPreprocessor(MolPreprocessor):
If the number of atoms in the molecule is less than this value,
the returned arrays is padded to have fixed size.
Setting negative value indicates do not pad returned array.
add_Hs (bool): If True, implicit Hs are added.
kekulize (bool): If True, Kekulizes the molecule.
return_is_real_node (bool): If True, also returns `is_real_node`.

"""

def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False):
super(RelGATPreprocessor, self).__init__(add_Hs=add_Hs)
if max_atoms >= 0 and out_size >= 0 and max_atoms > out_size:
raise ValueError('max_atoms {} must be less or equal to '
'out_size {}'.format(max_atoms, out_size))
self.max_atoms = max_atoms
self.out_size = out_size

def get_input_features(self, mol):
"""get input features

Args:
mol (Mol):

Returns:

"""
type_check_num_atoms(mol, self.max_atoms)
atom_array = construct_atomic_number_array(mol, out_size=self.out_size)
adj_array = construct_discrete_edge_matrix(mol, out_size=self.out_size)
return atom_array, adj_array
def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False,
kekulize=False, return_is_real_node=True):
super(RelGATPreprocessor, self).__init__(
max_atoms=max_atoms, out_size=out_size, add_Hs=add_Hs,
kekulize=kekulize, return_is_real_node=return_is_real_node)
19 changes: 4 additions & 15 deletions chainer_chemistry/dataset/preprocessors/relgcn_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from chainer_chemistry.dataset.preprocessors.ggnn_preprocessor \
import GGNNPreprocessor
from chainer_chemistry.dataset.preprocessors.ggnn_preprocessor import GGNNPreprocessor # NOQA


class RelGCNPreprocessor(GGNNPreprocessor):
Expand All @@ -17,22 +16,12 @@ class RelGCNPreprocessor(GGNNPreprocessor):
Setting negative value indicates do not pad returned array.
add_Hs (bool): If True, implicit Hs are added.
kekulize (bool): If True, Kekulizes the molecule.
return_is_real_node (bool): If True, also returns `is_real_node`.

"""

def __init__(self, max_atoms=-1, out_size=-1, add_Hs=False,
kekulize=False):
kekulize=False, return_is_real_node=True):
super(RelGCNPreprocessor, self).__init__(
max_atoms=max_atoms, out_size=out_size, add_Hs=add_Hs,
kekulize=kekulize)

def get_input_features(self, mol):
"""get input features

Args:
mol (Mol):

Returns:

"""
return super(RelGCNPreprocessor, self).get_input_features(mol)
kekulize=kekulize, return_is_real_node=return_is_real_node)
15 changes: 11 additions & 4 deletions chainer_chemistry/links/readout/ggnn_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@ def __init__(self, out_dim, hidden_dim=16, nobias=False,
self.activation = activation
self.activation_agg = activation_agg

def __call__(self, h, h0=None):
def __call__(self, h, h0=None, is_real_node=None):
# --- Readout part ---
# h, h0: (minibatch, atom, ch)
# h, h0: (minibatch, node, ch)
# is_real_node: (minibatch, node)
h1 = functions.concat((h, h0), axis=2) if h0 is not None else h

g1 = functions.sigmoid(self.i_layer(h1))
g2 = self.activation(self.j_layer(h1))
# sum along atom's axis
g = self.activation_agg(functions.sum(g1 * g2, axis=1))
g = g1 * g2
if is_real_node is not None:
# mask virtual node feature to be 0
mask = self.xp.broadcast_to(
is_real_node[:, :, None], g.shape)
g = g * mask
# sum along node axis
g = self.activation_agg(functions.sum(g, axis=1))
return g
9 changes: 6 additions & 3 deletions chainer_chemistry/models/ggnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, out_dim, hidden_dim=16, n_layers=4,
self.concat_hidden = concat_hidden
self.weight_tying = weight_tying

def __call__(self, atom_array, adj):
def __call__(self, atom_array, adj, is_real_node=None):
"""Forward propagation

Args:
Expand All @@ -66,6 +66,9 @@ def __call__(self, atom_array, adj):
molecule's `atom_index`-th atomic number
adj (numpy.ndarray): minibatch of adjacency matrix with edge-type
information
is_real_node (numpy.ndarray): 2-dim array (minibatch, num_nodes).
1 for real node, 0 for virtual node.
If `None`, all nodes are considered real nodes.

Returns:
~chainer.Variable: minibatch of fingerprint
Expand All @@ -82,13 +85,13 @@ def __call__(self, atom_array, adj):
message_layer_index = 0 if self.weight_tying else step
h = self.update_layers[message_layer_index](h, adj)
if self.concat_hidden:
g = self.readout_layers[step](h, h0)
g = self.readout_layers[step](h, h0, is_real_node)
g_list.append(g)

if self.concat_hidden:
return functions.concat(g_list, axis=1)
else:
g = self.readout_layers[0](h, h0)
g = self.readout_layers[0](h, h0, is_real_node)
return g

def reset_state(self):
Expand Down
9 changes: 6 additions & 3 deletions chainer_chemistry/models/relgat.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self, out_dim, hidden_dim=16, n_heads=3, negative_slope=0.2,
self.n_edge_types = n_edge_types
self.dropout_ratio = dropout_ratio

def __call__(self, atom_array, adj):
def __call__(self, atom_array, adj, is_real_node=None):
"""Forward propagation

Args:
Expand All @@ -91,6 +91,9 @@ def __call__(self, atom_array, adj):
molecule's `atom_index`-th atomic number
adj (numpy.ndarray): minibatch of adjacency matrix with edge-type
information
is_real_node (numpy.ndarray): 2-dim array (minibatch, num_nodes).
1 for real node, 0 for virtual node.
If `None`, all nodes are considered real nodes.

Returns:
~chainer.Variable: minibatch of fingerprint
Expand All @@ -106,11 +109,11 @@ def __call__(self, atom_array, adj):
message_layer_index = 0 if self.weight_tying else step
h = self.update_layers[message_layer_index](h, adj)
if self.concat_hidden:
g = self.readout_layers[step](h, h0)
g = self.readout_layers[step](h, h0, is_real_node)
g_list.append(g)

if self.concat_hidden:
return functions.concat(g_list, axis=1)
else:
g = self.readout_layers[0](h, h0)
g = self.readout_layers[0](h, h0, is_real_node)
return g
7 changes: 5 additions & 2 deletions chainer_chemistry/models/relgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,15 @@ def __init__(self, out_channels=64, num_edge_type=4, ch_list=None,
self.input_type = input_type
self.scale_adj = scale_adj

def __call__(self, x, adj):
def __call__(self, x, adj, is_real_node=None):
"""

Args:
x: (batchsize, num_nodes, in_channels)
adj: (batchsize, num_edge_type, num_nodes, num_nodes)
is_real_node (numpy.ndarray): 2-dim array (minibatch, num_nodes).
1 for real node, 0 for virtual node.
If `None`, all nodes are considered real nodes.

Returns: (batchsize, out_channels)

Expand All @@ -96,5 +99,5 @@ def __call__(self, x, adj):
adj = rescale_adj(adj)
for rgcn_conv in self.rgcn_convs:
h = functions.tanh(rgcn_conv(h, adj))
h = self.rgcn_readout(h)
h = self.rgcn_readout(h, is_real_node=is_real_node)
return h
7 changes: 5 additions & 2 deletions examples/molnet/train_molnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ def __init__(self, graph_conv, mlp=None):
if not isinstance(mlp, chainer.Link):
self.mlp = mlp

def __call__(self, atoms, adjs):
x = self.graph_conv(atoms, adjs)
def __call__(self, atoms, adjs, is_real_node=None):
if is_real_node is None:
x = self.graph_conv(atoms, adjs)
else:
x = self.graph_conv(atoms, adjs, is_real_node)
if self.mlp:
x = self.mlp(x)
return x
Expand Down
5 changes: 3 additions & 2 deletions examples/own_dataset/predict_own_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ def __init__(self, *args, **kwargs):
"""
super(ScaledGraphConvPredictor, self).__init__(*args, **kwargs)

def __call__(self, atoms, adjs):
h = super(ScaledGraphConvPredictor, self).__call__(atoms, adjs)
def __call__(self, atoms, adjs, is_real_node=None):
h = super(ScaledGraphConvPredictor, self).__call__(
atoms, adjs, is_real_node)
scaler_available = hasattr(self, 'scaler')
numpy_data = isinstance(h.data, numpy.ndarray)

Expand Down
7 changes: 5 additions & 2 deletions examples/own_dataset/train_own_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def __init__(self, graph_conv, mlp=None):
if not isinstance(mlp, chainer.Link):
self.mlp = mlp

def __call__(self, atoms, adjs):
h = self.graph_conv(atoms, adjs)
def __call__(self, atoms, adjs, is_real_node=None):
if is_real_node is None:
h = self.graph_conv(atoms, adjs)
else:
h = self.graph_conv(atoms, adjs, is_real_node)
if self.mlp:
h = self.mlp(h)
return h
Expand Down
5 changes: 3 additions & 2 deletions examples/qm9/predict_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def __init__(self, *args, **kwargs):
"""
super(ScaledGraphConvPredictor, self).__init__(*args, **kwargs)

def __call__(self, atoms, adjs):
h = super(ScaledGraphConvPredictor, self).__call__(atoms, adjs)
def __call__(self, atoms, adjs, is_real_node=None):
h = super(ScaledGraphConvPredictor, self).__call__(
atoms, adjs, is_real_node)
scaler_available = hasattr(self, 'scaler')
numpy_data = isinstance(h.data, numpy.ndarray)

Expand Down
9 changes: 6 additions & 3 deletions examples/qm9/train_qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ def __init__(self, graph_conv, mlp=None):
if not isinstance(mlp, chainer.Link):
self.mlp = mlp

def __call__(self, atoms, adjs):
x = self.graph_conv(atoms, adjs)
def __call__(self, atoms, adjs, is_real_node=None):
if is_real_node is None:
x = self.graph_conv(atoms, adjs)
else:
x = self.graph_conv(atoms, adjs, is_real_node)
if self.mlp:
x = self.mlp(x)
return x
Expand Down Expand Up @@ -245,7 +248,7 @@ def main():
dataset = D.get_qm9(preprocessor, labels=labels)

# Cache the loaded dataset.
os.makedirs(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
NumpyTupleDataset.save(dataset_cache_path, dataset)

# Scale the label values, if necessary.
Expand Down
Loading