diff --git a/nkululeko/feat_extract/feats_opensmile.py b/nkululeko/feat_extract/feats_opensmile.py index 2289c81f..607b55a2 100644 --- a/nkululeko/feat_extract/feats_opensmile.py +++ b/nkululeko/feat_extract/feats_opensmile.py @@ -8,12 +8,12 @@ class Opensmileset(Featureset): - def __init__(self, name, data_df): + def __init__(self, name, data_df, feats_type=None, config_file=None): super().__init__(name, data_df) self.featset = self.util.config_val("FEATS", "set", "eGeMAPSv02") try: self.feature_set = eval(f"opensmile.FeatureSet.{self.featset}") - #'eGeMAPSv02, ComParE_2016, GeMAPSv01a, eGeMAPSv01a': + # 'eGeMAPSv02, ComParE_2016, GeMAPSv01a, eGeMAPSv01a': except AttributeError: self.util.error( f"something is wrong with feature set: {self.featset}" diff --git a/nkululeko/feat_extract/feats_wavlm.py b/nkululeko/feat_extract/feats_wavlm.py index 03973c70..748791e1 100644 --- a/nkululeko/feat_extract/feats_wavlm.py +++ b/nkululeko/feat_extract/feats_wavlm.py @@ -59,10 +59,7 @@ def extract(self): frame_offset=int(start.total_seconds() * 16000), num_frames=int((end - start).total_seconds() * 16000), ) - if sampling_rate != 16000: - self.util.error( - f"sampling rate should be 16000 but is {sampling_rate}" - ) + assert sampling_rate == 16000, f"sampling rate should be 16000 but is {sampling_rate}" emb = self.get_embeddings(signal, sampling_rate, file) emb_series.iloc[idx] = emb self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index) diff --git a/nkululeko/feature_extractor.py b/nkululeko/feature_extractor.py index a5f8f5bf..468dc5ea 100644 --- a/nkululeko/feature_extractor.py +++ b/nkululeko/feature_extractor.py @@ -39,10 +39,12 @@ def extract(self): self.feats = pd.DataFrame() for feats_type in self.feats_types: store_name = f"{self.data_name}_{feats_type}" - self.feat_extractor = self._get_feat_extractor(store_name, feats_type) + self.feat_extractor = self._get_feat_extractor( + store_name, feats_type) self.feat_extractor.extract() self.feat_extractor.filter() - self.feats = pd.concat([self.feats, self.feat_extractor.df], axis=1) + self.feats = pd.concat( + [self.feats, self.feat_extractor.df], axis=1) return self.feats def extract_sample(self, signal, sr): @@ -53,14 +55,14 @@ def _get_feat_extractor(self, store_name, feats_type): if feat_extractor_class is None: self.util.error(f"unknown feats_type: {feats_type}") return feat_extractor_class( - f"{store_name}_{self.feats_designation}", self.data_df + f"{store_name}_{self.feats_designation}", self.data_df, feats_type ) def _get_feat_extractor_class(self, feats_type): if feats_type == "os": from nkululeko.feat_extract.feats_opensmile import Opensmileset - return Opensmileset + elif feats_type == "spectra": from nkululeko.feat_extract.feats_spectra import Spectraloader @@ -69,8 +71,10 @@ def _get_feat_extractor_class(self, feats_type): from nkululeko.feat_extract.feats_trill import TRILLset return TRILLset - elif feats_type.startswith(("wav2vec", "hubert", "wavlm", "spkrec")): + elif feats_type.startswith( + ("wav2vec", "hubert", "wavlm", "spkrec", "whisper")): return self._get_feat_extractor_by_prefix(feats_type) + elif feats_type in ( "audmodel", "auddim", @@ -89,16 +93,18 @@ def _get_feat_extractor_class(self, feats_type): return None def _get_feat_extractor_by_prefix(self, feats_type): - prefix, _, ext = feats_type.partition("_") + prefix, _, ext = feats_type.partition("-") from importlib import import_module - module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}") - class_name = f"{prefix.capitalize()}{ext.capitalize()}set" + module = import_module( + f"nkululeko.feat_extract.feats_{prefix.lower()}") + class_name = f"{prefix.capitalize()}" return getattr(module, class_name) def _get_feat_extractor_by_name(self, feats_type): from importlib import import_module - module = import_module(f"nkululeko.feat_extract.feats_{feats_type.lower()}") + module = import_module( + f"nkululeko.feat_extract.feats_{feats_type.lower()}") class_name = f"{feats_type.capitalize()}Set" return getattr(module, class_name)