Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNVM][TESTING] Add two testing symbols: dqn and dcgan #1294

Merged
merged 6 commits into from
Jun 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nnvm/python/nnvm/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@
from . import mlp
from . import resnet
from . import vgg
from . import dcgan
from . import dqn
from . import yolo2_detection
90 changes: 90 additions & 0 deletions nnvm/python/nnvm/testing/dcgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# pylint: disable=unused-argument
"""
Symbol of the generator of DCGAN

Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py

Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
from .. import symbol as sym
from . utils import create_workload

def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])

pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]

net = sym.conv2d_transpose(data,
kernel_size=kshape,
strides=stride,
channels=oshape[0],
padding=(pad_y, pad_x),
output_padding=(adj_y, adj_x),
use_bias=False,
name=name)
return net

def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = sym.batch_norm(net, epsilon=eps, name="%s_bn" % prefix)
net = sym.relu(net, name="%s_act" % prefix)
return net

def get_symbol(oshape, ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 32, "Only support 32x32 image"
assert oshape[-2] == 32, "Only support 32x32 image"

code = sym.Variable("data") if code is None else code
net = sym.dense(code, name="g1", units=4*4*ngf*4, use_bias=False)
net = sym.relu(net)
# 4 x 4
net = sym.reshape(net, shape=(-1, ngf * 4, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d(
net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
net = sym.tanh(net)
return net


def get_workload(batch_size, oshape=(3, 32, 32), ngf=128, random_len=100, dtype="float32"):
"""Get benchmark workload for a DCGAN generator

Parameters
----------
batch_size : int
The batch size used in the model
oshape : tuple, optional
The shape of output image, layout="CHW"
ngf: int, optional
The number of final feature maps in the generator
random_len : int, optional
The length of random input
dtype : str, optional
The data type

Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(oshape=oshape, ngf=ngf)
return create_workload(net, batch_size, (random_len, ), dtype)
71 changes: 71 additions & 0 deletions nnvm/python/nnvm/testing/dqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Symbol of Nature DQN

Reference:
Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""

from .. import symbol as sym
from . utils import create_workload

def get_symbol(num_actions=18):
"""get symbol of nature dqn"""
data = sym.Variable(name='data')
net = sym.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name='conv1')
net = sym.relu(net, name='relu1')
net = sym.conv2d(net, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
channels=64, name='conv2')
net = sym.relu(net, name='relu2')
net = sym.conv2d(net, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
channels=64, name='conv3')
net = sym.relu(net, name='relu3')
net = sym.flatten(net, name='flatten')
net = sym.dense(net, units=512, name='fc4')
net = sym.relu(net, name='relu4')
net = sym.dense(net, units=num_actions, name='fc5')

return net


def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"):
"""Get benchmark workload for a Deep Q Network

Parameters
----------
batch_size : int
The batch size used in the model
num_actions : int, optional
Number of actions
image_shape : tuple, optional
The input image shape
dtype : str, optional
The data type

Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_symbol(num_actions=num_actions)
return create_workload(net, batch_size, image_shape, dtype)
10 changes: 9 additions & 1 deletion nnvm/tests/python/frontend/mxnet/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""MXNet and NNVM model zoo."""
from __future__ import absolute_import
from . import mlp, resnet, vgg
from . import mlp, resnet, vgg, dqn, dcgan
import nnvm.testing

__all__ = ['mx_mlp', 'nnvm_mlp', 'mx_resnet', 'nnvm_resnet', 'mx_vgg', 'nnvm_vgg']
Expand All @@ -26,3 +26,11 @@
mx_vgg[num_layer] = vgg.get_symbol(_num_class, num_layer)
nnvm_vgg[num_layer] = nnvm.testing.vgg.get_workload(
1, _num_class, num_layers=num_layer)[0]

# dqn
mx_dqn = dqn.get_symbol()
nnvm_dqn = nnvm.testing.dqn.get_workload(1)[0]

# dcgan generator
mx_dcgan = dcgan.get_symbol()
nnvm_dcgan = nnvm.testing.dcgan.get_workload(1)[0]
63 changes: 63 additions & 0 deletions nnvm/tests/python/frontend/mxnet/model_zoo/dcgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# pylint: disable=unused-argument
"""
The MXNet symbol of DCGAN generator

Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py

Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""

import mxnet as mx

def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])
pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]

net = mx.sym.Deconvolution(data,
kernel=kshape,
stride=stride,
pad=(pad_y, pad_x),
adj=(adj_y, adj_x),
num_filter=oshape[0],
no_bias=True,
name=name)
return net

def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12

net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = mx.sym.BatchNorm(net, eps=eps, name="%s_bn" % prefix)
net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
return net

def get_symbol(oshape=(3, 32, 32), ngf=128, code=None):
"""get symbol of dcgan generator"""
assert oshape[-1] == 32, "Only support 32x32 image"
assert oshape[-2] == 32, "Only support 32x32 image"

code = mx.sym.Variable("data") if code is None else code
net = mx.sym.FullyConnected(code, name="g1", num_hidden=4*4*ngf*4, no_bias=True, flatten=False)
net = mx.sym.Activation(net, act_type='relu')
# 4 x 4
net = mx.sym.reshape(net, shape=(-1, ngf * 4, 4, 4))
# 8 x 8
net = deconv2d_bn_relu(
net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
net = deconv2d_bn_relu(
net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
net = deconv2d(
net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), name="g4_deconv")
net = mx.sym.Activation(net, act_type='tanh')
return net
27 changes: 27 additions & 0 deletions nnvm/tests/python/frontend/mxnet/model_zoo/dqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
The mxnet symbol of Nature DQN

Reference:
Mnih, Volodymyr, et al.
"Human-level control through deep reinforcement learning."
Nature 518.7540 (2015): 529.
"""

import mxnet as mx

def get_symbol(num_action=18):
data = mx.sym.Variable(name='data')
net = mx.sym.Convolution(data, kernel=(8, 8), stride=(4, 4),
num_filter=32, name='conv1')
net = mx.sym.Activation(net, act_type='relu', name='relu1')
net = mx.sym.Convolution(net, kernel=(4, 4), stride=(2, 2),
num_filter=64, name='conv2')
net = mx.sym.Activation(net, act_type='relu', name='relu2')
net = mx.sym.Convolution(net, kernel=(3, 3), stride=(1, 1),
num_filter=64, name='conv3')
net = mx.sym.Activation(net, act_type='relu', name='relu3')
net = mx.sym.FullyConnected(net, num_hidden=512, name='fc4')
net = mx.sym.Activation(net, act_type='relu', name='relu4')
net = mx.sym.FullyConnected(net, num_hidden=num_action, name='fc5', flatten=False)

return net
14 changes: 14 additions & 0 deletions nnvm/tests/python/frontend/mxnet/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ def test_resnet():
nnvm_sym = model_zoo.nnvm_resnet[n]
compare_graph(from_mx_sym, nnvm_sym)

def test_dqn():
mx_sym = model_zoo.mx_dqn
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_dqn
compare_graph(from_mx_sym, nnvm_sym)

def test_dcgan():
mx_sym = model_zoo.mx_dcgan
from_mx_sym, _ = nnvm.frontend.from_mxnet(mx_sym)
nnvm_sym = model_zoo.nnvm_dcgan
compare_graph(from_mx_sym, nnvm_sym)

def test_multi_outputs():
def compose(F, **kwargs):
x = F.sym.Variable('x')
Expand All @@ -48,3 +60,5 @@ def compose(F, **kwargs):
test_vgg()
test_resnet()
test_multi_outputs()
test_dqn()
test_dcgan()