
New generative model algorithm: Rectified Flow #184

Merged (39 commits, Apr 17, 2024)

Changes from all commits
a77a7d8
add RectifiedFlow
autumn-2-net Mar 31, 2024
1db6afe
add RectifiedFlow
autumn-2-net Mar 31, 2024
e0cc821
add Variance RectifiedFlow
autumn-2-net Mar 31, 2024
08e9c1c
add RectifiedFlow config
autumn-2-net Mar 31, 2024
ed7ca4d
fix
autumn-2-net Mar 31, 2024
ee188c8
fix
autumn-2-net Mar 31, 2024
c12e3bc
variance support new loss
autumn-2-net Mar 31, 2024
4667424
variance support new loss
autumn-2-net Mar 31, 2024
cbb3b30
fix discrete timestep
autumn-2-net Mar 31, 2024
ca30895
add more sample
autumn-2-net Mar 31, 2024
3ecb848
fix sample
autumn-2-net Mar 31, 2024
82c2383
fix sample
autumn-2-net Mar 31, 2024
af92556
fix sample
autumn-2-net Mar 31, 2024
ff20a19
fix esp
autumn-2-net Mar 31, 2024
4080d0e
fix pitch_predictor
autumn-2-net Apr 1, 2024
ec2f40d
fix
autumn-2-net Apr 1, 2024
801d267
fix bug
autumn-2-net Apr 1, 2024
e534d28
fix bug
autumn-2-net Apr 1, 2024
04abfa4
fix bug
autumn-2-net Apr 1, 2024
d786ed1
fix bug
autumn-2-net Apr 1, 2024
20cbc88
Merge branch 'main' into RectifiedFlow
yqzhishen Apr 3, 2024
0592d43
Migrate to continuous acceleration
yqzhishen Apr 3, 2024
db53b9e
Test removing sinusoidal pos embedding
yqzhishen Apr 4, 2024
b5c0ec6
Fix default value
yqzhishen Apr 4, 2024
a554cea
Support Rectified Flow in ONNX exporters
yqzhishen Apr 4, 2024
3f3f2c4
Fix step range and dsconfig key
yqzhishen Apr 5, 2024
087b929
Fix missing key
yqzhishen Apr 5, 2024
821b3dd
Adjust condition and add optimization logic
yqzhishen Apr 5, 2024
4f0457a
add limit loss weights
autumn-2-net Apr 5, 2024
b908cb8
Merge remote-tracking branch 'origin/RectifiedFlow' into RectifiedFlow
autumn-2-net Apr 5, 2024
7e5bd96
Merge remote-tracking branch 'origin/RectifiedFlow' into RectifiedFlow
yqzhishen Apr 5, 2024
9606e4e
Revert "Test removing sinusoidal pos embedding"
yqzhishen Apr 5, 2024
dc6896b
Refactor code and configs, rename modules, adjust default LR schedule
yqzhishen Apr 16, 2024
07b0ea8
Delete debugging code
yqzhishen Apr 16, 2024
cf3a0b8
Fix unexpected float64 in ONNX graphs
yqzhishen Apr 17, 2024
2cfe603
Add configuration schemas for Rectified Flow stuff
yqzhishen Apr 17, 2024
ee7f3cd
Update module name in comments, docs and images
yqzhishen Apr 17, 2024
45db81f
Update image and references
yqzhishen Apr 17, 2024
7539fe8
Fix readme
yqzhishen Apr 17, 2024
21 changes: 16 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -44,13 +44,24 @@ TBD

## References

- Original DiffSinger: [paper](https://arxiv.org/abs/2105.02446), [implementation](https://github.com/MoonInTheRiver/DiffSinger)
### Original Paper & Implementation

- Paper: [DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism](https://arxiv.org/abs/2105.02446)
- Implementation: [MoonInTheRiver/DiffSinger](https://github.com/MoonInTheRiver/DiffSinger)

### Generative Models & Algorithms

- Denoising Diffusion Probabilistic Models (DDPM): [paper](https://arxiv.org/abs/2006.11239), [implementation](https://github.com/hojonathanho/diffusion)
- [DDIM](https://arxiv.org/abs/2010.02502) for diffusion sampling acceleration
- [PNDM](https://arxiv.org/abs/2202.09778) for diffusion sampling acceleration
- [DPM-Solver++](https://github.com/LuChengTHU/dpm-solver) for diffusion sampling acceleration
- [UniPC](https://github.com/wl-zhao/UniPC) for diffusion sampling acceleration
- Rectified Flow (RF): [paper](https://arxiv.org/abs/2209.03003), [implementation](https://github.com/gnobitab/RectifiedFlow)

### Dependencies & Submodules

- [HiFi-GAN](https://github.com/jik876/hifi-gan) and [NSF](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf) for waveform reconstruction
- [pc-ddsp](https://github.com/yxlllc/pc-ddsp) for waveform reconstruction
- [DDIM](https://arxiv.org/abs/2010.02502) for diffusion sampling acceleration
- [PNDM](https://arxiv.org/abs/2202.09778) for diffusion sampling acceleration
- [DPM-Solver++](https://github.com/LuChengTHU/dpm-solver) for diffusion sampling acceleration
- [UniPC](https://github.com/wl-zhao/UniPC) for diffusion sampling acceleration
- [RMVPE](https://github.com/Dream-High/RMVPE) and yxlllc's [fork](https://github.com/yxlllc/RMVPE) for pitch extraction

## Disclaimer
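The Rectified Flow reference added above trains a straight-line velocity field from noise to data and samples by integrating it. A rough sketch of the Euler sampler that the new `sampling_algorithm: euler` / `sampling_steps` config keys select (illustrative code, not this repository's API — `velocity_fn` stands in for the trained backbone):

```python
import torch

def reflow_euler_sample(velocity_fn, noise, steps=20):
    # Euler integration of dx/dt = v(x, t) from t=0 (noise) to t=1 (data).
    # `velocity_fn` is a placeholder for the trained backbone network.
    x = noise
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((x.shape[0],), i * dt, device=x.device)
        x = x + velocity_fn(x, t) * dt
    return x
```

With a perfectly straight (rectified) flow, even very few Euler steps recover the data, which is why the default `sampling_steps: 20` can replace hundreds of DDPM denoising steps.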
20 changes: 13 additions & 7 deletions configs/acoustic.yaml
Expand Up @@ -63,21 +63,28 @@ use_tension_embed: false
use_key_shift_embed: false
use_speed_embed: false

diffusion_type: reflow
time_scale_factor: 1000
timesteps: 1000
max_beta: 0.02
rel_pos: true
sampling_algorithm: euler
sampling_steps: 20
diff_accelerator: ddim
diff_speedup: 10
hidden_size: 256
residual_layers: 20
residual_channels: 512
dilation_cycle_length: 4 # *
diff_decoder_type: 'wavenet'
diff_loss_type: l2
backbone_type: 'wavenet'
main_loss_type: l2
main_loss_log_norm: false
schedule_type: 'linear'

# shallow diffusion
use_shallow_diffusion: true
T_start: 0.4
T_start_infer: 0.4
K_step: 400
K_step_infer: 400

@@ -100,20 +107,19 @@ num_sanity_val_steps: 1
optimizer_args:
lr: 0.0006
lr_scheduler_args:
step_size: 30000
gamma: 0.5
step_size: 10000
gamma: 0.75
max_batch_frames: 50000
max_batch_size: 64
dataset_size_key: 'lengths'
val_with_vocoder: true
val_check_interval: 2000
num_valid_plots: 10
max_updates: 200000
max_updates: 160000
num_ckpt_keep: 5
permanent_ckpt_start: 120000
permanent_ckpt_start: 80000
permanent_ckpt_interval: 20000


finetune_enabled: false
finetune_ckpt_path: null

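Gathered from the diff above, the reflow-related keys form one compact block (a sketch using this file's values; the comments are interpretive assumptions, not documented descriptions):

```yaml
diffusion_type: reflow          # 'ddpm' remains the legacy alternative
time_scale_factor: 1000         # scaling applied to continuous t in [0, 1] (assumed)
sampling_algorithm: euler
sampling_steps: 20
main_loss_type: l2              # renamed from diff_loss_type
main_loss_log_norm: false
backbone_type: 'wavenet'        # renamed from diff_decoder_type
T_start: 0.4                    # continuous counterpart of K_step for shallow diffusion
T_start_infer: 0.4
```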
9 changes: 6 additions & 3 deletions configs/templates/config_acoustic.yaml
@@ -53,7 +53,10 @@ residual_channels: 512
residual_layers: 20

# shallow diffusion
diffusion_type: reflow
use_shallow_diffusion: true
T_start: 0.4
T_start_infer: 0.4
K_step: 300
K_step_infer: 300
shallow_diffusion_args:
@@ -73,11 +76,11 @@ optimizer_args:
lr: 0.0006
lr_scheduler_args:
scheduler_cls: torch.optim.lr_scheduler.StepLR
step_size: 30000
gamma: 0.5
step_size: 10000
gamma: 0.75
max_batch_frames: 50000
max_batch_size: 64
max_updates: 200000
max_updates: 160000

num_valid_plots: 10
val_with_vocoder: true
8 changes: 5 additions & 3 deletions configs/templates/config_variance.yaml
@@ -65,6 +65,8 @@ use_glide_embed: false
glide_types: [up, down]
glide_embed_scale: 11.313708498984760 # sqrt(128)

diffusion_type: reflow

pitch_prediction_args:
pitd_norm_min: -8.0
pitd_norm_max: 8.0
@@ -89,16 +91,16 @@ optimizer_args:
lr: 0.0006
lr_scheduler_args:
scheduler_cls: torch.optim.lr_scheduler.StepLR
step_size: 12000
step_size: 10000
gamma: 0.75
max_batch_frames: 80000
max_batch_size: 48
max_updates: 288000
max_updates: 160000

num_valid_plots: 10
val_check_interval: 2000
num_ckpt_keep: 5
permanent_ckpt_start: 180000
permanent_ckpt_start: 80000
permanent_ckpt_interval: 10000
pl_trainer_devices: 'auto'
pl_trainer_precision: '16-mixed'
15 changes: 10 additions & 5 deletions configs/variance.yaml
@@ -97,12 +97,17 @@ lambda_dur_loss: 1.0
lambda_pitch_loss: 1.0
lambda_var_loss: 1.0

diffusion_type: reflow # ddpm
time_scale_factor: 1000
schedule_type: 'linear'
K_step: 1000
timesteps: 1000
max_beta: 0.02
diff_decoder_type: 'wavenet'
diff_loss_type: l2
backbone_type: 'wavenet'
main_loss_type: l2
main_loss_log_norm: true
sampling_algorithm: euler
sampling_steps: 20
diff_accelerator: ddim
diff_speedup: 10

@@ -111,16 +116,16 @@ num_sanity_val_steps: 1
optimizer_args:
lr: 0.0006
lr_scheduler_args:
step_size: 12000
step_size: 10000
gamma: 0.75
max_batch_frames: 80000
max_batch_size: 48
dataset_size_key: 'lengths'
val_check_interval: 2000
num_valid_plots: 10
max_updates: 288000
max_updates: 160000
num_ckpt_keep: 5
permanent_ckpt_start: 180000
permanent_ckpt_start: 80000
permanent_ckpt_interval: 10000

finetune_enabled: false
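For orientation, the `reflow` objective configured above (with `main_loss_type: l2`) can be sketched as follows. This is the generic Rectified Flow loss from the referenced paper, not this repository's exact code; `velocity_fn`, the conditioning argument, and the tensor shapes are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def reflow_training_loss(velocity_fn, x_data, cond=None):
    # Sample t ~ U(0, 1), build the straight-line interpolation
    # x_t = (1 - t) * noise + t * x_data, and regress the predicted
    # velocity onto the constant target x_data - noise (L2 loss).
    noise = torch.randn_like(x_data)
    t = torch.rand(x_data.shape[0], device=x_data.device)
    t_b = t.view(-1, *([1] * (x_data.dim() - 1)))  # broadcast t over feature dims
    x_t = (1.0 - t_b) * noise + t_b * x_data
    v_target = x_data - noise
    v_pred = velocity_fn(x_t, t, cond)
    return F.mse_loss(v_pred, v_target)
```

The `main_loss_log_norm` option above presumably reweights this loss over t; that detail is not shown here.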
80 changes: 53 additions & 27 deletions deployment/exporters/acoustic_exporter.py
@@ -50,7 +50,7 @@ def __init__(
self.aux_decoder_class_name = remove_suffix(
self.model.aux_decoder.decoder.__class__.__name__, 'ONNX'
) if self.model.use_shallow_diffusion else None
self.denoiser_class_name = remove_suffix(self.model.diffusion.denoise_fn.__class__.__name__, 'ONNX')
self.backbone_class_name = remove_suffix(self.model.diffusion.backbone.__class__.__name__, 'ONNX')
self.diffusion_class_name = remove_suffix(self.model.diffusion.__class__.__name__, 'ONNX')

# Attributes for exporting
@@ -78,6 +78,14 @@ def __init__(
if self.freeze_spk is not None:
self.model.fs2.register_buffer('frozen_spk_embed', self._perform_spk_mix(self.freeze_spk[1]))

# Acceleration
if self.model.diffusion_type == 'ddpm':
self.acceleration_type = 'discrete'
elif self.model.diffusion_type == 'reflow':
self.acceleration_type = 'continuous'
else:
raise ValueError(f'Invalid diffusion type: {self.model.diffusion_type}')

def build_model(self) -> DiffSingerAcousticONNX:
model = DiffSingerAcousticONNX(
vocab_size=len(self.vocab),
@@ -138,9 +146,13 @@ def export_attachments(self, path: Path):
dsconfig['use_speed_embed'] = self.expose_velocity
for variance in VARIANCE_CHECKLIST:
dsconfig[f'use_{variance}_embed'] = (variance in self.model.fs2.variance_embed_list)
# shallow diffusion
dsconfig['use_shallow_diffusion'] = self.model.use_shallow_diffusion
dsconfig['max_depth'] = self.model.diffusion.k_step
# sampling acceleration and shallow diffusion
dsconfig['use_continuous_acceleration'] = self.acceleration_type == 'continuous'
dsconfig['use_variable_depth'] = self.model.use_shallow_diffusion
if self.acceleration_type == 'continuous':
dsconfig['max_depth'] = 1 - self.model.diffusion.t_start
else:
dsconfig['max_depth'] = self.model.diffusion.k_step
# mel specification
dsconfig['sample_rate'] = hparams['audio_sample_rate']
dsconfig['hop_size'] = hparams['hop_size']
@@ -234,55 +246,69 @@ def _torch_export_model(self):

condition = torch.rand((1, n_frames, hparams['hidden_size']), device=self.device)

# Prepare inputs for denoiser tracing and GaussianDiffusion scripting
# Prepare inputs for backbone tracing and GaussianDiffusion scripting
shape = (1, 1, hparams['audio_num_mel_bins'], n_frames)
noise = torch.randn(shape, device=self.device)
x_start = torch.randn((1, n_frames, hparams['audio_num_mel_bins']),device=self.device)
step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()
x_aux = torch.randn((1, n_frames, hparams['audio_num_mel_bins']), device=self.device)
if self.acceleration_type == 'continuous':
time_or_step = (torch.rand((1,), device=self.device) * self.model.diffusion.time_scale_factor).float()
dummy_depth = torch.tensor(0.1, device=self.device)
dummy_steps_or_speedup = 5
else:
time_or_step = (torch.rand((1,), device=self.device) * self.model.diffusion.k_step).long()
dummy_depth = torch.tensor(0.1, device=self.device)
dummy_steps_or_speedup = 200

print(f'Tracing {self.denoiser_class_name} denoiser...')
diffusion = self.model.view_as_diffusion()
diffusion.diffusion.denoise_fn = torch.jit.trace(
diffusion.diffusion.denoise_fn,
(
noise,
step,
condition.transpose(1, 2)
print(f'Tracing {self.backbone_class_name} backbone...')
if self.model.diffusion_type == 'ddpm':
major_mel_decoder = self.model.view_as_diffusion()
elif self.model.diffusion_type == 'reflow':
major_mel_decoder = self.model.view_as_reflow()
else:
raise ValueError(f'Invalid diffusion type: {self.model.diffusion_type}')
major_mel_decoder.diffusion.set_backbone(
torch.jit.trace(
major_mel_decoder.diffusion.backbone,
(
noise,
time_or_step,
condition.transpose(1, 2)
)
)
)

print(f'Scripting {self.diffusion_class_name}...')
diffusion_inputs = [
condition,
*([x_start, 100] if self.model.use_shallow_diffusion else [])
*([x_aux, dummy_depth] if self.model.use_shallow_diffusion else [])
]
diffusion = torch.jit.script(
diffusion,
major_mel_decoder = torch.jit.script(
major_mel_decoder,
example_inputs=[
(
*diffusion_inputs,
1 # p_sample branch
),
(
*diffusion_inputs,
200 # p_sample_plms branch
dummy_steps_or_speedup # p_sample_plms branch
)
]
)

# PyTorch ONNX export for GaussianDiffusion
print(f'Exporting {self.diffusion_class_name}...')
torch.onnx.export(
diffusion,
major_mel_decoder,
(
*diffusion_inputs,
200
dummy_steps_or_speedup
),
self.diffusion_cache_path,
input_names=[
'condition',
*(['x_start', 'depth'] if self.model.use_shallow_diffusion else []),
'speedup'
*(['x_aux', 'depth'] if self.model.use_shallow_diffusion else []),
('steps' if self.acceleration_type == 'continuous' else 'speedup')
],
output_names=[
'mel'
@@ -291,7 +317,7 @@ def _torch_export_model(self):
'condition': {
1: 'n_frames'
},
**({'x_start': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}),
**({'x_aux': {1: 'n_frames'}} if self.model.use_shallow_diffusion else {}),
'mel': {
1: 'n_frames'
}
@@ -337,8 +363,8 @@ def _optimize_diffusion_graph(self, diffusion: onnx.ModelProto) -> onnx.ModelProto:
onnx_helper.graph_fold_back_to_squeeze(diffusion.graph)
onnx_helper.graph_extract_conditioner_projections(
graph=diffusion.graph, op_type='Conv',
weight_pattern=r'diffusion\.denoise_fn\.residual_layers\.\d+\.conditioner_projection\.weight',
alias_prefix='/diffusion/denoise_fn/cache'
weight_pattern=r'diffusion\..*\.conditioner_projection\.weight',
alias_prefix='/diffusion/backbone/cache'
)
onnx_helper.graph_remove_unused_values(diffusion.graph)
print(f'Running ONNX Simplifier #2 on {self.diffusion_class_name}...')
@@ -361,7 +387,7 @@ def _merge_fs2_aux_diffusion_graphs(self, fs2: onnx.ModelProto, diffusion: onnx.ModelProto):
merged = onnx.compose.merge_models(
fs2, diffusion, io_map=[
('condition', 'condition'),
*([('aux_mel', 'x_start')] if self.model.use_shallow_diffusion else []),
*([('aux_mel', 'x_aux')] if self.model.use_shallow_diffusion else []),
],
prefix1='', prefix2='', doc_string='',
producer_name=fs2.producer_name, producer_version=fs2.producer_version,
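The `max_depth` branching in the exporter above writes different units into `dsconfig` depending on the acceleration type; condensed into a standalone sketch (default values borrowed from configs/acoustic.yaml, function name hypothetical):

```python
def export_max_depth(acceleration_type: str, t_start: float = 0.4, k_step: int = 400):
    # Continuous (reflow) models express shallow-diffusion depth as the
    # fraction of total time after t_start; discrete (ddpm) models keep
    # the raw K_step count.
    if acceleration_type == 'continuous':
        return 1.0 - t_start
    return k_step
```

This is why the new `dsconfig` key `use_continuous_acceleration` is exported alongside `max_depth`: a consumer must know which unit it is reading.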