From 9606e4e9574b7e95f1ac28d2f72866ea4fded51c Mon Sep 17 00:00:00 2001
From: yqzhishen
Date: Fri, 5 Apr 2024 20:39:12 +0800
Subject: [PATCH] Revert "Test removing sinusoidal pos embedding"

This reverts commit db53b9e2
---
 modules/diffusion/wavenet.py | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/modules/diffusion/wavenet.py b/modules/diffusion/wavenet.py
index b5e79ebc..0a1400d3 100644
--- a/modules/diffusion/wavenet.py
+++ b/modules/diffusion/wavenet.py
@@ -67,17 +67,13 @@ def __init__(self, in_dims, n_feats, *, n_layers=20, n_chans=256, n_dilates=4):
         super().__init__()
         self.in_dims = in_dims
         self.n_feats = n_feats
-        self.use_sinusoidal_pos_embed = hparams.get('use_sinusoidal_pos_embed', True)
         self.input_projection = Conv1d(in_dims * n_feats, n_chans, 1)
-        if self.use_sinusoidal_pos_embed:
-            self.diffusion_embedding = SinusoidalPosEmb(n_chans)
-            self.mlp = nn.Sequential(
-                nn.Linear(n_chans, n_chans * 4),
-                nn.Mish(),
-                nn.Linear(n_chans * 4, n_chans)
-            )
-        else:
-            self.diffusion_embedding = nn.Linear(1, n_chans)
+        self.diffusion_embedding = SinusoidalPosEmb(n_chans)
+        self.mlp = nn.Sequential(
+            nn.Linear(n_chans, n_chans * 4),
+            nn.Mish(),
+            nn.Linear(n_chans * 4, n_chans)
+        )
         self.residual_layers = nn.ModuleList([
             ResidualBlock(
                 encoder_hidden=hparams['hidden_size'],
@@ -104,11 +100,8 @@ def forward(self, spec, diffusion_step, cond):
         x = self.input_projection(x)  # [B, C, T]
         x = F.relu(x)
 
-        if self.use_sinusoidal_pos_embed:
-            diffusion_step = self.diffusion_embedding(diffusion_step)
-            diffusion_step = self.mlp(diffusion_step)
-        else:
-            diffusion_step = self.diffusion_embedding(diffusion_step[..., None])
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        diffusion_step = self.mlp(diffusion_step)
         skip = []
         for layer in self.residual_layers:
             x, skip_connection = layer(x, cond, diffusion_step)
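
Note: for reference, below is a minimal sketch of the kind of sinusoidal step
embedding this revert restores. The actual SinusoidalPosEmb class is defined
elsewhere in modules/diffusion/wavenet.py and is not part of this patch; the
base of 10000 and the sin/cos concatenation here are assumptions taken from
the standard DDPM/Transformer formulation, not the project's exact code.

    import math
    import torch
    from torch import nn

    class SinusoidalPosEmb(nn.Module):
        """Map a scalar diffusion step to a fixed sinusoidal embedding
        (sketch of the standard formulation; assumed, not the repo's code)."""

        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            # x: [B] tensor of diffusion steps -> [B, dim] embedding
            half_dim = self.dim // 2
            scale = math.log(10000) / (half_dim - 1)
            freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
            args = x[:, None].float() * freqs[None, :]
            # Concatenate sine and cosine components along the feature axis
            return torch.cat((args.sin(), args.cos()), dim=-1)

Under these assumptions, the restored forward path computes
self.mlp(self.diffusion_embedding(diffusion_step)), mapping a [B] batch of
steps to a [B, n_chans] conditioning vector for each residual block, whereas
the reverted experiment fed the raw step value through nn.Linear(1, n_chans).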