from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from numbers import Number
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch


class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:

    1. Init weights, downloading config.json + model.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    (A minimal usage sketch appears at the bottom of this file.)
    '''

    REPO_ID = 'hexgrad/Kokoro-82M'

    def __init__(self, config: Union[Dict, str, None] = None, model: Optional[str] = None):
        super().__init__()
        if not isinstance(config, dict):
            if not config:
                logger.debug("No config provided, downloading from HF")
                config = hf_hub_download(repo_id=KModel.REPO_ID, filename='config.json')
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
                logger.debug(f"Loaded config: {config}")
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], **config['istftnet']
        )
        if not model:
            model = hf_hub_download(repo_id=KModel.REPO_ID, filename='kokoro-v1_0.pth')
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            try:
                getattr(self, key).load_state_dict(state_dict)
            except Exception:
                # Checkpoints saved from a DataParallel/DistributedDataParallel wrapper
                # prefix every key with 'module.'; strip the prefix and retry non-strictly.
                logger.debug(f"Did not load {key} from state_dict")
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                getattr(self, key).load_state_dict(state_dict, strict=False)

    @property
    def device(self):
        return self.bert.device

    @dataclass
    class Output:
        audio: torch.FloatTensor
        pred_dur: Optional[torch.LongTensor] = None

    @torch.no_grad()
    def forward(
        self,
        phonemes: str,
        ref_s: torch.FloatTensor,
        speed: Number = 1,
        return_output: bool = False  # MARK: BACKWARD COMPAT
    ) -> Union['KModel.Output', torch.FloatTensor]:
        # Map phonemes -> input_ids, silently dropping symbols missing from the vocab
        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
        input_lengths = torch.LongTensor([input_ids.shape[-1]]).to(self.device)
        # Build a padding mask (True where a position exceeds the sequence length)
        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
        # Contextual phoneme embeddings from ALBERT, projected to the hidden dim
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        ref_s = ref_s.to(self.device)
        # Second half of the reference vector conditions prosody; first half conditions the decoder
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        # Predict per-phoneme durations in frames, scaled by the speed factor
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
        logger.debug(f"pred_dur: {pred_dur}")
        # Expand phoneme-level features to frame rate via a hard 0/1 alignment matrix
        indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
        pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
        pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        # Predict pitch (F0) and energy (N) curves, then decode to a waveform with the iSTFTNet decoder
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu()
        return self.Output(audio=audio, pred_dur=pred_dur.cpu()) if return_output else audio
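

# --- Usage sketch (illustrative, not part of the upstream module) ---
# A minimal, hedged example of the workflow the class docstring describes:
# build one KModel (config.json and kokoro-v1_0.pth are fetched from HF when
# not supplied), then call it with a phoneme string and a reference style
# vector. The phoneme string and the (1, 256) shape of `ref_s` below are
# assumptions for illustration only; in practice both come from a KPipeline
# and a downloaded voice tensor.
if __name__ == '__main__':
    model = KModel().eval()
    phonemes = 'həlˈoʊ wˈɜɹld'   # hypothetical phoneme string; normally produced by a KPipeline
    ref_s = torch.zeros(1, 256)  # placeholder style vector; real voices supply meaningful values
    out = model(phonemes, ref_s, speed=1, return_output=True)
    print(out.audio.shape, out.pred_dur)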