Skip to content

Commit

Permalink
Merge branch 'felixbur:main' into add-db
Browse files Browse the repository at this point in the history
  • Loading branch information
bagustris authored May 17, 2024
2 parents 8cdb6dd + 47783b7 commit 2432ddb
Show file tree
Hide file tree
Showing 10 changed files with 622 additions and 46 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Version 0.85.0
--------------
* first version with finetuning wav2vec2 layers

Version 0.84.1
--------------
* made resample independent of config file
Expand Down
2 changes: 1 addition & 1 deletion nkululeko/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION="0.84.1"
VERSION="0.85.0"
SAMPLING_RATE = 16000
7 changes: 6 additions & 1 deletion nkululeko/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,12 @@ def extract_feats(self):
feats_name = "_".join(ast.literal_eval(
glob_conf.config["DATA"]["databases"]))
self.feats_test, self.feats_train = pd.DataFrame(), pd.DataFrame()
feats_types = self.util.config_val_list("FEATS", "type", ["os"])
feats_types = self.util.config_val_list("FEATS", "type", [])
# for some models no features are needed
if len(feats_types) == 0:
self.util.debug("no feature extractor specified.")
self.feats_train, self.feats_test = pd.DataFrame(), pd.DataFrame()
return
self.feature_extractor = FeatureExtractor(
df_train, feats_types, feats_name, "train"
)
Expand Down
9 changes: 3 additions & 6 deletions nkululeko/feat_extract/feats_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,19 @@ def init_model(self):
model_name = f"openai/{self.feat_type}"
self.model = WhisperModel.from_pretrained(model_name).to(self.device)
print(f"intialized Whisper model on {self.device}")
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
model_name)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
self.model_initialized = True

def extract(self):
"""Extract the features or load them from disk if present."""
store = self.util.get_path("store")
storage = f"{store}{self.name}.pkl"
extract = self.util.config_val(
"FEATS", "needs_feature_extraction", False)
extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
if extract or no_reuse or not os.path.isfile(storage):
if not self.model_initialized:
self.init_model()
self.util.debug(
"extracting whisper embeddings, this might take a while...")
self.util.debug("extracting whisper embeddings, this might take a while...")
emb_series = []
for (file, start, end), _ in audeer.progress_bar(
self.data_df.iterrows(),
Expand Down
89 changes: 56 additions & 33 deletions nkululeko/modelrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,12 @@ def do_epochs(self):
highest = 0
else:
highest = 100000
# for all epochs
for epoch in range(epoch_num):
if only_test:
self.model.load(self.run, epoch)
self.util.debug(f"reusing model: {self.model.store_path}")
self.model.reset_test(self.df_test, self.feats_test)
else:
self.model.set_id(self.run, epoch)
self.model.train()
if self.model.model_type == "finetuned":
# epochs are handled by Huggingface API
self.model.train()
report = self.model.predict()
# TODO: find out the best epoch
epoch = epoch_num
report.set_id(self.run, epoch)
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
reports.append(report)
Expand All @@ -67,32 +63,53 @@ def do_epochs(self):
if plot_epochs:
self.util.debug(f"plotting conf matrix to {plot_name}")
report.plot_confmatrix(plot_name, epoch)
store_models = self.util.config_val("EXP", "save", False)
plot_best_model = self.util.config_val("PLOT", "best_model", False)
if (store_models or plot_best_model) and (
not only_test
): # in any case the model needs to be stored to disk.
self.model.store()
if patience:
patience = int(patience)
result = report.result.get_result()
if self.util.high_is_good():
if result > highest:
highest = result
patience_counter = 0
else:
patience_counter += 1
else:
# for all epochs
for epoch in range(epoch_num):
if only_test:
self.model.load(self.run, epoch)
self.util.debug(f"reusing model: {self.model.store_path}")
self.model.reset_test(self.df_test, self.feats_test)
else:
if result < highest:
highest = result
patience_counter = 0
self.model.set_id(self.run, epoch)
self.model.train()
report = self.model.predict()
report.set_id(self.run, epoch)
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
reports.append(report)
self.util.debug(
f"run: {self.run} epoch: {epoch}: result: "
f"{reports[-1].get_result().get_test_result()}"
)
if plot_epochs:
self.util.debug(f"plotting conf matrix to {plot_name}")
report.plot_confmatrix(plot_name, epoch)
store_models = self.util.config_val("EXP", "save", False)
plot_best_model = self.util.config_val("PLOT", "best_model", False)
if (store_models or plot_best_model) and (
not only_test
): # in any case the model needs to be stored to disk.
self.model.store()
if patience:
patience = int(patience)
result = report.result.get_result()
if self.util.high_is_good():
if result > highest:
highest = result
patience_counter = 0
else:
patience_counter += 1
else:
patience_counter += 1
if patience_counter >= patience:
self.util.debug(
f"reached patience ({str(patience)}): early stopping"
)
break
if result < highest:
highest = result
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
self.util.debug(
f"reached patience ({str(patience)}): early stopping"
)
break

if not plot_epochs:
# Do at least one confusion matrix plot
Expand Down Expand Up @@ -133,6 +150,12 @@ def _select_model(self, model_type):
self.model = Bayes_model(
self.df_train, self.df_test, self.feats_train, self.feats_test
)
elif model_type == "finetune":
from nkululeko.models.model_tuned import Pretrained_model

self.model = Pretrained_model(
self.df_train, self.df_test, self.feats_train, self.feats_test
)
elif model_type == "gmm":
from nkululeko.models.model_gmm import GMM_model

Expand Down
9 changes: 9 additions & 0 deletions nkululeko/models/finetune_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Code based on @jwagner
"""

import dataclasses
import typing

Expand Down Expand Up @@ -148,6 +152,11 @@ def forward(
logits_cat=logits_cat,
)

def predict(self, signal):
    """Run the model on a single signal and return the output as NumPy.

    Args:
        signal: NumPy array passed to ``torch.from_numpy``. Presumably a
            raw waveform with a leading batch axis of size 1 -- TODO
            confirm against the callers.

    Returns:
        numpy.ndarray: element 0 of the model output, converted to NumPy
        and indexed at 0 along its first axis (presumably the batch
        axis -- verify against ``forward``).
    """
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        result = self(torch.from_numpy(signal))
    # Tensor.numpy() only works for CPU tensors, so move the output to the
    # CPU first; this is a no-op when it already lives there.
    return result[0].detach().cpu().numpy()[0]


class ModelWithPreProcessing(Model):

Expand Down
2 changes: 1 addition & 1 deletion nkululeko/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def set_model_type(self, type):
self.model_type = type

def is_ann(self):
if self.model_type == "ann":
if (self.model_type == "ann") or (self.model_type == "finetuned"):
return True
else:
return False
Expand Down
Loading

0 comments on commit 2432ddb

Please sign in to comment.