Skip to content

Commit

Permalink
Force vocoder to only accept mel_base = 'e'
Browse files (browse the repository at this point in the history)
  • Loading branch information
yqzhishen committed Jul 12, 2024
1 parent a1c0a5f commit da80060
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
8 changes: 6 additions & 2 deletions deployment/exporters/nsf_hifigan_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def build_model(self) -> nn.Module:
config_path = self.model_path.with_name('config.json')
with open(config_path, 'r', encoding='utf8') as f:
config = json.load(f)
model = NSFHiFiGANONNX(config, mel_base=hparams.get('mel_base', '10')).eval().to(self.device)
assert hparams.get('mel_base') == 'e', (
"Mel base must be set to \'e\' according to 2nd stage of the migration plan. "
"See https://github.com/openvpi/DiffSinger/releases/tag/v2.3.0 for more details."
)
model = NSFHiFiGANONNX(config).eval().to(self.device)
load_ckpt(model.generator, str(self.model_path),
prefix_in_ckpt=None, key_in_ckpt='generator',
strict=True, device=self.device)
Expand Down Expand Up @@ -67,7 +71,7 @@ def export_attachments(self, path: Path):
'num_mel_bins': hparams['audio_num_mel_bins'],
'mel_fmin': hparams['fmin'],
'mel_fmax': hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2,
'mel_base': str(hparams.get('mel_base', '10')),
'mel_base': 'e',
'mel_scale': 'slaney',
}, fw, sort_keys=False)
print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')
Expand Down
7 changes: 1 addition & 6 deletions deployment/modules/nsf_hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@

# noinspection SpellCheckingInspection
class NSFHiFiGANONNX(torch.nn.Module):
def __init__(self, attrs: dict, mel_base='e'):
def __init__(self, attrs: dict):
super().__init__()
self.mel_base = str(mel_base)
assert self.mel_base in ['e', '10'], "mel_base must be 'e', '10' or 10."
self.generator = Generator(AttrDict(attrs))

def forward(self, mel: torch.Tensor, f0: torch.Tensor):
mel = mel.transpose(1, 2)
if self.mel_base != 'e':
# log10 to log mel
mel = mel * 2.30259
wav = self.generator(mel, f0)
return wav.squeeze(1)

0 comments on commit da80060

Please sign in to comment.