Skip to content

Commit

Permalink
Self contained NAT Transformer from https://arxiv.org/abs/1805.11063
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 199180774
  • Loading branch information
royaurko authored and whr94621 committed Jun 12, 2018
1 parent 788e229 commit f08cd6f
Show file tree
Hide file tree
Showing 5 changed files with 446 additions and 14 deletions.
14 changes: 14 additions & 0 deletions tensor2tensor/data_generators/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,20 @@ class TranslateDistillProblem(TranslateProblem):
def is_generate_per_split(self):
return True

def example_reading_spec(self):
data_fields = {"dist_targets": tf.VarLenFeature(tf.int64)}

if self.has_inputs:
data_fields["inputs"] = tf.VarLenFeature(tf.int64)

# hack: ignoring true targets and putting dist_targets in targets
data_items_to_decoders = {
"inputs": tf.contrib.slim.tfexample_decoder.Tensor("inputs"),
"targets": tf.contrib.slim.tfexample_decoder.Tensor("dist_targets"),
}

return (data_fields, data_items_to_decoders)

def get_or_create_vocab(self, data_dir, tmp_dir, force_get=False):
"""Get vocab for distill problems."""
# We assume that vocab file is present in data_dir directory where the
Expand Down
41 changes: 29 additions & 12 deletions tensor2tensor/layers/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,32 +695,42 @@ def get_vq_bottleneck(bottleneck_size, hidden_size):
return means, ema_means, ema_count


def vq_nearest_neighbor(x, means):
def vq_nearest_neighbor(x, means, soft_em=False, num_samples=10):
"""Find the nearest element in means to elements in x."""
bottleneck_size = common_layers.shape_list(means)[0]
x_norm_sq = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
means_norm_sq = tf.reduce_sum(tf.square(means), axis=-1, keepdims=True)
scalar_prod = tf.matmul(x, means, transpose_b=True)
dist = x_norm_sq + tf.transpose(means_norm_sq) - 2 * scalar_prod
x_means_idx = tf.argmax(-dist, axis=-1)
x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
if soft_em:
x_means_idx = tf.multinomial(-dist, num_samples=num_samples)
x_means_hot = tf.one_hot(
x_means_idx, depth=common_layers.shape_list(means)[0])
x_means_hot = tf.reduce_sum(x_means_hot, axis=1)
else:
x_means_idx = tf.argmax(-dist, axis=-1)
x_means_hot = tf.one_hot(x_means_idx, bottleneck_size)
x_means_hot_flat = tf.reshape(x_means_hot, [-1, bottleneck_size])
x_means = tf.matmul(x_means_hot_flat, means)
e_loss = tf.reduce_mean(tf.square(x - tf.stop_gradient(x_means)))
return x_means_hot, e_loss


def vq_discrete_bottleneck(x,
bottleneck_size,
bottleneck_bits,
beta=0.25,
decay=0.999,
epsilon=1e-5):
epsilon=1e-5,
soft_em=False,
num_samples=10):
"""Simple vector quantized discrete bottleneck."""
bottleneck_size = 2**bottleneck_bits
x_shape = common_layers.shape_list(x)
hidden_size = x_shape[-1]
means, ema_means, ema_count = get_vq_bottleneck(bottleneck_size, hidden_size)
x = tf.reshape(x, [-1, hidden_size])
x_means_hot, e_loss = vq_nearest_neighbor(x, means)
x_means_hot, e_loss = vq_nearest_neighbor(
x, means, soft_em=soft_em, num_samples=num_samples)

# Update the ema variables
updated_ema_count = moving_averages.assign_moving_average(
Expand All @@ -731,7 +741,6 @@ def vq_discrete_bottleneck(x,
zero_debias=False)

dw = tf.matmul(x_means_hot, x, transpose_a=True)

updated_ema_means = tf.identity(moving_averages.assign_moving_average(
ema_means, dw, decay, zero_debias=False))
n = tf.reduce_sum(updated_ema_count, axis=-1, keepdims=True)
Expand All @@ -754,8 +763,7 @@ def vq_discrete_unbottleneck(x, hidden_size):
bottleneck_size = common_layers.shape_list(x)[-1]
means, _, _ = get_vq_bottleneck(bottleneck_size, hidden_size)
result = tf.matmul(tf.reshape(x, [-1, x_shape[-1]]), means)
return tf.reshape(result,
x_shape[:-1] + [common_layers.shape_list(means)[-1]])
return tf.reshape(result, x_shape[:-1] + [hidden_size])


def tanh_discrete_bottleneck(x, bottleneck_bits, bottleneck_noise,
Expand Down Expand Up @@ -825,9 +833,18 @@ def parametrized_bottleneck(x, hparams):
hparams.discretize_warmup_steps, hparams.mode,
hparams.isemhash_noise_dev, hparams.isemhash_mix_prob)
if hparams.bottleneck_kind == "vq":
bottleneck_size = 2**hparams.bottleneck_bits
return vq_discrete_bottleneck(x, bottleneck_size, hparams.vq_beta,
return vq_discrete_bottleneck(x, hparams.bottleneck_bits, hparams.vq_beta,
hparams.vq_decay, hparams.vq_epsilon)
if hparams.bottleneck_kind == "em":
return vq_discrete_bottleneck(
x,
hparams.bottleneck_bits,
hparams.vq_beta,
hparams.vq_decay,
hparams.vq_epsilon,
soft_em=True,
num_samples=hparams.vq_num_samples)

raise ValueError("Unsupported hparams.bottleneck_kind %s"
% hparams.bottleneck_kind)

Expand All @@ -839,7 +856,7 @@ def parametrized_unbottleneck(x, hidden_size, hparams):
if hparams.bottleneck_kind == "isemhash":
return isemhash_unbottleneck(
x, hidden_size, hparams.isemhash_filter_size_multiplier)
if hparams.bottleneck_kind == "vq":
if hparams.bottleneck_kind in ["vq", "em"]:
return vq_discrete_unbottleneck(x, hidden_size)
raise ValueError("Unsupported hparams.bottleneck_kind %s"
% hparams.bottleneck_kind)
5 changes: 3 additions & 2 deletions tensor2tensor/layers/discretization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def testNearestNeighbors(self):
self.assertTrue(np.all(x_means_hot_eval == x_means_hot_test))

def testGetVQBottleneck(self):
bottleneck_size = 4
bottleneck_bits = 2
bottleneck_size = 2**bottleneck_bits
hidden_size = 3
means, _, ema_count = discretization.get_vq_bottleneck(bottleneck_size,
hidden_size)
Expand All @@ -148,7 +149,7 @@ def testVQNearestNeighbors(self):

def testVQDiscreteBottleneck(self):
x = tf.constant([[0, 0.9, 0], [0.8, 0., 0.]], dtype=tf.float32)
x_means_hot, _ = discretization.vq_discrete_bottleneck(x, bottleneck_size=4)
x_means_hot, _ = discretization.vq_discrete_bottleneck(x, bottleneck_bits=2)
with self.test_session() as sess:
tf.global_variables_initializer().run()
x_means_hot_eval = sess.run(x_means_hot)
Expand Down
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from tensor2tensor.models.research import rl
from tensor2tensor.models.research import super_lm
from tensor2tensor.models.research import transformer_moe
from tensor2tensor.models.research import transformer_nat
from tensor2tensor.models.research import transformer_revnet
from tensor2tensor.models.research import transformer_sketch
from tensor2tensor.models.research import transformer_symshard
Expand Down
Loading

0 comments on commit f08cd6f

Please sign in to comment.