Revert "Test removing sinusoidal pos embedding"
This reverts commit db53b9e
yqzhishen committed Apr 5, 2024
1 parent 7e5bd96 commit 9606e4e
Showing 1 changed file with 8 additions and 15 deletions.
modules/diffusion/wavenet.py: 23 changes (8 additions & 15 deletions)
@@ -67,17 +67,13 @@ def __init__(self, in_dims, n_feats, *, n_layers=20, n_chans=256, n_dilates=4):
         super().__init__()
         self.in_dims = in_dims
         self.n_feats = n_feats
-        self.use_sinusoidal_pos_embed = hparams.get('use_sinusoidal_pos_embed', True)
         self.input_projection = Conv1d(in_dims * n_feats, n_chans, 1)
-        if self.use_sinusoidal_pos_embed:
-            self.diffusion_embedding = SinusoidalPosEmb(n_chans)
-            self.mlp = nn.Sequential(
-                nn.Linear(n_chans, n_chans * 4),
-                nn.Mish(),
-                nn.Linear(n_chans * 4, n_chans)
-            )
-        else:
-            self.diffusion_embedding = nn.Linear(1, n_chans)
+        self.diffusion_embedding = SinusoidalPosEmb(n_chans)
+        self.mlp = nn.Sequential(
+            nn.Linear(n_chans, n_chans * 4),
+            nn.Mish(),
+            nn.Linear(n_chans * 4, n_chans)
+        )
         self.residual_layers = nn.ModuleList([
             ResidualBlock(
                 encoder_hidden=hparams['hidden_size'],
@@ -104,11 +100,8 @@ def forward(self, spec, diffusion_step, cond):
         x = self.input_projection(x)  # [B, C, T]

         x = F.relu(x)
-        if self.use_sinusoidal_pos_embed:
-            diffusion_step = self.diffusion_embedding(diffusion_step)
-            diffusion_step = self.mlp(diffusion_step)
-        else:
-            diffusion_step = self.diffusion_embedding(diffusion_step[..., None])
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        diffusion_step = self.mlp(diffusion_step)
         skip = []
         for layer in self.residual_layers:
             x, skip_connection = layer(x, cond, diffusion_step)
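Note: the restored SinusoidalPosEmb module is defined elsewhere in the repository and does not appear in this diff. Below is a minimal sketch of a typical DDPM-style sinusoidal step embedding for reference; the class name and constructor argument follow the diff, but the body is an assumption and may differ from the project's actual implementation.

import math

import torch
from torch import nn


class SinusoidalPosEmb(nn.Module):
    """Sketch of a standard sinusoidal diffusion-step embedding (assumed implementation)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: [B] tensor of diffusion step indices
        half_dim = self.dim // 2
        scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        args = x[:, None].float() * freqs[None, :]          # [B, half_dim]
        return torch.cat((args.sin(), args.cos()), dim=-1)  # [B, dim]

With this module restored, forward expects diffusion_step as a 1-D tensor of shape [B]; the embedding maps it to [B, n_chans] before the Mish MLP, rather than projecting the raw step through a single nn.Linear(1, n_chans) as in the reverted commit.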
