Skip to content

Commit

Permalink
Modify feature extractor, WavLM, and openSMILE ("os") feature sets
Browse files (browse the repository at this point in the history)
  • Loading branch information
bagustris committed Apr 23, 2024
1 parent 764e487 commit a7a05cd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
4 changes: 2 additions & 2 deletions nkululeko/feat_extract/feats_opensmile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


class Opensmileset(Featureset):
def __init__(self, name, data_df):
def __init__(self, name, data_df, feats_type=None, config_file=None):
super().__init__(name, data_df)
self.featset = self.util.config_val("FEATS", "set", "eGeMAPSv02")
try:
self.feature_set = eval(f"opensmile.FeatureSet.{self.featset}")
#'eGeMAPSv02, ComParE_2016, GeMAPSv01a, eGeMAPSv01a':
# 'eGeMAPSv02, ComParE_2016, GeMAPSv01a, eGeMAPSv01a':
except AttributeError:
self.util.error(
f"something is wrong with feature set: {self.featset}"
Expand Down
5 changes: 1 addition & 4 deletions nkululeko/feat_extract/feats_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ def extract(self):
frame_offset=int(start.total_seconds() * 16000),
num_frames=int((end - start).total_seconds() * 16000),
)
if sampling_rate != 16000:
self.util.error(
f"sampling rate should be 16000 but is {sampling_rate}"
)
assert sampling_rate == 16000, f"sampling rate should be 16000 but is {sampling_rate}"
emb = self.get_embeddings(signal, sampling_rate, file)
emb_series.iloc[idx] = emb
self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
Expand Down
24 changes: 15 additions & 9 deletions nkululeko/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ def extract(self):
self.feats = pd.DataFrame()
for feats_type in self.feats_types:
store_name = f"{self.data_name}_{feats_type}"
self.feat_extractor = self._get_feat_extractor(store_name, feats_type)
self.feat_extractor = self._get_feat_extractor(
store_name, feats_type)
self.feat_extractor.extract()
self.feat_extractor.filter()
self.feats = pd.concat([self.feats, self.feat_extractor.df], axis=1)
self.feats = pd.concat(
[self.feats, self.feat_extractor.df], axis=1)
return self.feats

def extract_sample(self, signal, sr):
Expand All @@ -53,14 +55,14 @@ def _get_feat_extractor(self, store_name, feats_type):
if feat_extractor_class is None:
self.util.error(f"unknown feats_type: {feats_type}")
return feat_extractor_class(
f"{store_name}_{self.feats_designation}", self.data_df
f"{store_name}_{self.feats_designation}", self.data_df, feats_type
)

def _get_feat_extractor_class(self, feats_type):
if feats_type == "os":
from nkululeko.feat_extract.feats_opensmile import Opensmileset

return Opensmileset

elif feats_type == "spectra":
from nkululeko.feat_extract.feats_spectra import Spectraloader

Expand All @@ -69,8 +71,10 @@ def _get_feat_extractor_class(self, feats_type):
from nkululeko.feat_extract.feats_trill import TRILLset

return TRILLset
elif feats_type.startswith(("wav2vec", "hubert", "wavlm", "spkrec")):
elif feats_type.startswith(
("wav2vec", "hubert", "wavlm", "spkrec", "whisper")):
return self._get_feat_extractor_by_prefix(feats_type)

elif feats_type in (
"audmodel",
"auddim",
Expand All @@ -89,16 +93,18 @@ def _get_feat_extractor_class(self, feats_type):
return None

def _get_feat_extractor_by_prefix(self, feats_type):
prefix, _, ext = feats_type.partition("_")
prefix, _, ext = feats_type.partition("-")
from importlib import import_module

module = import_module(f"nkululeko.feat_extract.feats_{prefix.lower()}")
class_name = f"{prefix.capitalize()}{ext.capitalize()}set"
module = import_module(
f"nkululeko.feat_extract.feats_{prefix.lower()}")
class_name = f"{prefix.capitalize()}"
return getattr(module, class_name)

def _get_feat_extractor_by_name(self, feats_type):
from importlib import import_module

module = import_module(f"nkululeko.feat_extract.feats_{feats_type.lower()}")
module = import_module(
f"nkululeko.feat_extract.feats_{feats_type.lower()}")
class_name = f"{feats_type.capitalize()}Set"
return getattr(module, class_name)

0 comments on commit a7a05cd

Please sign in to comment.