From 5c03a724ed1ecd6c688b75d6895ff62790b1df68 Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Tue, 23 Apr 2024 18:04:39 +0900 Subject: [PATCH] modified: nkululeko/feat_extract/feats_agender.py modified: nkululeko/feat_extract/feats_auddim.py modified: nkululeko/feat_extract/feats_audmodel.py modified: nkululeko/feat_extract/feats_clap.py modified: nkululeko/feat_extract/feats_hubert.py modified: nkululeko/feat_extract/feats_oxbow.py modified: nkululeko/feat_extract/feats_praat.py modified: nkululeko/feat_extract/feats_trill.py modified: nkululeko/feat_extract/feats_wav2vec2.py modified: nkululeko/feat_extract/featureset.py modified: nkululeko/feature_extractor.py --- nkululeko/feat_extract/feats_agender.py | 10 +++++---- nkululeko/feat_extract/feats_auddim.py | 8 ++++--- nkululeko/feat_extract/feats_audmodel.py | 8 ++++--- nkululeko/feat_extract/feats_clap.py | 16 ++++++++------ nkululeko/feat_extract/feats_hubert.py | 5 +++-- nkululeko/feat_extract/feats_oxbow.py | 27 ++++++++++++++---------- nkululeko/feat_extract/feats_praat.py | 13 +++++++----- nkululeko/feat_extract/feats_trill.py | 16 ++++++++------ nkululeko/feat_extract/feats_wav2vec2.py | 23 ++++++++++++++------ nkululeko/feat_extract/featureset.py | 9 +++++--- nkululeko/feature_extractor.py | 10 ++++++--- 11 files changed, 92 insertions(+), 53 deletions(-) diff --git a/nkululeko/feat_extract/feats_agender.py b/nkululeko/feat_extract/feats_agender.py index fc47bf98..97ac22bc 100644 --- a/nkululeko/feat_extract/feats_agender.py +++ b/nkululeko/feat_extract/feats_agender.py @@ -9,16 +9,17 @@ import audinterface -class AudModelAgenderSet(Featureset): +class AgenderSet(Featureset): """ Embeddings from the wav2vec2. based model finetuned on agender data, described in the paper "Speech-based Age and Gender Prediction with Transformers" https://arxiv.org/abs/2306.16962 """ - def __init__(self, name, data_df): - super().__init__(name, data_df) + def __init__(self, name, data_df, feats_type): + super().__init__(name, data_df, feats_type) self.model_loaded = False + self.feats_type = feats_type def _load_model(self): model_url = "https://zenodo.org/record/7761387/files/w2v2-L-robust-6-age-gender.25c844af-1.1.1.zip" @@ -28,7 +29,8 @@ def _load_model(self): if not os.path.isdir(model_root): cache_root = audeer.mkdir("cache") model_root = audeer.mkdir(model_root) - archive_path = audeer.download_url(model_url, cache_root, verbose=True) + archive_path = audeer.download_url( + model_url, cache_root, verbose=True) audeer.extract_archive(archive_path, model_root) device = self.util.config_val("MODEL", "device", "cpu") self.model = audonnx.load(model_root, device=device) diff --git a/nkululeko/feat_extract/feats_auddim.py b/nkululeko/feat_extract/feats_auddim.py index e9d3cbab..26cbae5f 100644 --- a/nkululeko/feat_extract/feats_auddim.py +++ b/nkululeko/feat_extract/feats_auddim.py @@ -21,9 +21,10 @@ class AuddimSet(Featureset): https://arxiv.org/abs/2203.07378. """ - def __init__(self, name, data_df): - super().__init__(name, data_df) + def __init__(self, name, data_df, feats_type): + super().__init__(name, data_df, feats_type) self.model_loaded = False + self.feats_types = feats_type def _load_model(self): model_url = "https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip" @@ -31,7 +32,8 @@ def _load_model(self): if not os.path.isdir(model_root): cache_root = audeer.mkdir("cache") model_root = audeer.mkdir(model_root) - archive_path = audeer.download_url(model_url, cache_root, verbose=True) + archive_path = audeer.download_url( + model_url, cache_root, verbose=True) audeer.extract_archive(archive_path, model_root) cuda = "cuda" if torch.cuda.is_available() else "cpu" device = self.util.config_val("MODEL", "device", cuda) diff --git a/nkululeko/feat_extract/feats_audmodel.py b/nkululeko/feat_extract/feats_audmodel.py index c2f890f4..10487e2b 100644 --- a/nkululeko/feat_extract/feats_audmodel.py +++ b/nkululeko/feat_extract/feats_audmodel.py @@ -19,9 +19,10 @@ class AudmodelSet(Featureset): https://arxiv.org/abs/2203.07378. """ - def __init__(self, name, data_df): - super().__init__(name, data_df) + def __init__(self, name, data_df, feats_type): + super().__init__(name, data_df, feats_type) self.model_loaded = False + self.feats_type = feats_type def _load_model(self): model_url = "https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip" @@ -29,7 +30,8 @@ def _load_model(self): if not os.path.isdir(model_root): cache_root = audeer.mkdir("cache") model_root = audeer.mkdir(model_root) - archive_path = audeer.download_url(model_url, cache_root, verbose=True) + archive_path = audeer.download_url( + model_url, cache_root, verbose=True) audeer.extract_archive(archive_path, model_root) cuda = "cuda" if torch.cuda.is_available() else "cpu" device = self.util.config_val("MODEL", "device", cuda) diff --git a/nkululeko/feat_extract/feats_clap.py b/nkululeko/feat_extract/feats_clap.py index a0758414..da1f5de4 100644 --- a/nkululeko/feat_extract/feats_clap.py +++ b/nkululeko/feat_extract/feats_clap.py @@ -11,14 +11,15 @@ import audiofile -class Clap(Featureset): +class ClapSet(Featureset): """Class to extract laion's clap embeddings (https://github.com/LAION-AI/CLAP)""" - def __init__(self, name, data_df): + def __init__(self, name, data_df, feats_type): """Constructor. is_train is needed to distinguish from test/dev sets, because they use the codebook from the training""" - super().__init__(name, data_df) + super().__init__(name, data_df, feats_type) self.device = self.util.config_val("MODEL", "device", "cpu") self.model_initialized = False + self.feat_type = feats_type def init_model(self): # load model @@ -32,12 +33,14 @@ def extract(self): store = self.util.get_path("store") store_format = self.util.config_val("FEATS", "store_format", "pkl") storage = f"{store}{self.name}.{store_format}" - extract = self.util.config_val("FEATS", "needs_feature_extraction", False) + extract = self.util.config_val( + "FEATS", "needs_feature_extraction", False) no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False")) if extract or no_reuse or not os.path.isfile(storage): if not self.model_initialized: self.init_model() - self.util.debug("extracting clap embeddings, this might take a while...") + self.util.debug( + "extracting clap embeddings, this might take a while...") emb_series = pd.Series(index=self.data_df.index, dtype=object) length = len(self.data_df.index) for idx, (file, start, end) in enumerate( @@ -51,7 +54,8 @@ def extract(self): ) emb = self.get_embeddings(signal, sampling_rate) emb_series[idx] = emb - self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index) + self.df = pd.DataFrame( + emb_series.values.tolist(), index=self.data_df.index) self.util.write_store(self.df, storage, store_format) try: glob_conf.config["DATA"]["needs_feature_extraction"] = "false" diff --git a/nkululeko/feat_extract/feats_hubert.py b/nkululeko/feat_extract/feats_hubert.py index 9f3fc917..4a63dbe9 100644 --- a/nkululeko/feat_extract/feats_hubert.py +++ b/nkululeko/feat_extract/feats_hubert.py @@ -1,6 +1,7 @@ # feats_hubert.py # HuBERT feature extractor for Nkululeko -# example feat_type = "hubert-large-ll60k", "hubert-xlarge-ll60k" +# example feat_type = "hubert-large-ll60k", "hubert-xlarge-ll60k", +# "hubert-base-ls960", hubert-large-ls960-ft", "hubert-xlarge-ls960-ft" import os @@ -22,7 +23,7 @@ class Hubert(Featureset): def __init__(self, name, data_df, feat_type): """Constructor. is_train is needed to distinguish from test/dev sets, because they use the codebook from the training""" - super().__init__(name, data_df) + super().__init__(name, data_df, feat_type) # check if device is not set, use cuda if available cuda = "cuda" if torch.cuda.is_available() else "cpu" self.device = self.util.config_val("MODEL", "device", cuda) diff --git a/nkululeko/feat_extract/feats_oxbow.py b/nkululeko/feat_extract/feats_oxbow.py index acab3e74..779fc4e2 100644 --- a/nkululeko/feat_extract/feats_oxbow.py +++ b/nkululeko/feat_extract/feats_oxbow.py @@ -10,9 +10,10 @@ class Openxbow(Featureset): """Class to extract openXBOW processed opensmile features (https://github.com/openXBOW)""" - def __init__(self, name, data_df, is_train=False): + def __init__(self, name, data_df, feats_type, is_train=False): """Constructor. is_train is needed to distinguish from test/dev sets, because they use the codebook from the training""" - super().__init__(name, data_df) + super().__init__(name, data_df, feats_type) + self.feats_types = feats_type self.is_train = is_train def extract(self): @@ -21,11 +22,13 @@ def extract(self): self.feature_set = eval(f"opensmile.FeatureSet.{self.featset}") store = self.util.get_path("store") storage = f"{store}{self.name}_{self.featset}.pkl" - extract = self.util.config_val("FEATS", "needs_feature_extraction", False) + extract = self.util.config_val( + "FEATS", "needs_feature_extraction", False) no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False")) if extract or no_reuse or not os.path.isfile(storage): # extract smile features first - self.util.debug("extracting openSmile features, this might take a while...") + self.util.debug( + "extracting openSmile features, this might take a while...") smile = opensmile.Smile( feature_set=self.feature_set, feature_level=opensmile.FeatureLevel.LowLevelDescriptors, @@ -48,7 +51,13 @@ def extract(self): # save the smile features smile_df.to_csv(lld_name, sep=";", header=False) # get the path of the xbow java jar file - xbow_path = self.util.config_val("FEATS", "xbow.model", "../openXBOW/") + xbow_path = self.util.config_val( + "FEATS", "xbow.model", "openXBOW") + # check if JAR file exist + if not os.path.isfile(f"{xbow_path}/openXBOW.jar"): + # download using wget if not exist and locate in xbow_path + os.system( + f"git clone https://github.com/openXBOW/openXBOW") # get the size of the codebook size = self.util.config_val("FEATS", "size", 500) # get the number of assignements @@ -57,16 +66,12 @@ def extract(self): if self.is_train: # store the codebook os.system( - f"java -jar {xbow_path}openXBOW.jar -i" - f" {lld_name} -standardizeInput -log -o" - f" {xbow_name} -size {size} -a {assignments} -B" - f" {codebook_name}" + f"java -jar {xbow_path}/openXBOW.jar -i {lld_name} -standardizeInput -log -o {xbow_name} -size {size} -a {assignments} -B {codebook_name}" ) else: # use the codebook os.system( - f"java -jar {xbow_path}openXBOW.jar -i {lld_name} " - f" -o {xbow_name} -b {codebook_name}" + f"java -jar {xbow_path}/openXBOW.jar -i {lld_name} -o {xbow_name} -b {codebook_name}" ) # read in the result from disk xbow_df = pd.read_csv(xbow_name, sep=";", header=None) diff --git a/nkululeko/feat_extract/feats_praat.py b/nkululeko/feat_extract/feats_praat.py index 369e3691..6bc69049 100644 --- a/nkululeko/feat_extract/feats_praat.py +++ b/nkululeko/feat_extract/feats_praat.py @@ -18,18 +18,20 @@ class PraatSet(Featureset): """ - def __init__(self, name, data_df): - super().__init__(name, data_df) + def __init__(self, name, data_df, feats_type): + super().__init__(name, data_df, feats_type) def extract(self): """Extract the features based on the initialized dataset or re-open them when found on disk.""" store = self.util.get_path("store") store_format = self.util.config_val("FEATS", "store_format", "pkl") storage = f"{store}{self.name}.{store_format}" - extract = self.util.config_val("FEATS", "needs_feature_extraction", False) + extract = self.util.config_val( + "FEATS", "needs_feature_extraction", False) no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False")) if extract or no_reuse or not os.path.isfile(storage): - self.util.debug("extracting Praat features, this might take a while...") + self.util.debug( + "extracting Praat features, this might take a while...") self.df = feinberg_praat.compute_features(self.data_df.index) self.df = self.df.set_index(self.data_df.index) for i, col in enumerate(self.df.columns): @@ -52,7 +54,8 @@ def extract(self): self.df = self.df.astype(float) def extract_sample(self, signal, sr): - import audiofile, audformat + import audiofile + import audformat tmp_audio_names = ["praat_audio_tmp.wav"] audiofile.write(tmp_audio_names[0], signal, sr) diff --git a/nkululeko/feat_extract/feats_trill.py b/nkululeko/feat_extract/feats_trill.py index b2ccc091..b3c42b4d 100644 --- a/nkululeko/feat_extract/feats_trill.py +++ b/nkululeko/feat_extract/feats_trill.py @@ -1,4 +1,5 @@ # feats_trill.py +import tensorflow_hub as hub import os import tensorflow as tf from numpy.core.numeric import tensordot @@ -11,7 +12,6 @@ # Import TF 2.X and make sure we're running eager. assert tf.executing_eagerly() -import tensorflow_hub as hub class TRILLset(Featureset): @@ -20,7 +20,7 @@ class TRILLset(Featureset): """https://ai.googleblog.com/2020/06/improving-speech-representations-and.html""" # Initialization of the class - def __init__(self, name, data_df): + def __init__(self, name, data_df, feats_type): """ Initialize the class with name, data and Util instance Also loads the model from hub @@ -31,7 +31,7 @@ def __init__(self, name, data_df): :type data_df: DataFrame :return: None """ - super().__init__(name, data_df) + super().__init__(name, data_df, feats_type) # Load the model from the configured path model_path = self.util.config_val( "FEATS", @@ -39,20 +39,24 @@ def __init__(self, name, data_df): "https://tfhub.dev/google/nonsemantic-speech-benchmark/trill/3", ) self.module = hub.load(model_path) + self.feats_type = feats_type def extract(self): store = self.util.get_path("store") storage = f"{store}{self.name}.pkl" - extract = self.util.config_val("FEATS", "needs_feature_extraction", False) + extract = self.util.config_val( + "FEATS", "needs_feature_extraction", False) no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False")) if extract or no_reuse or not os.path.isfile(storage): - self.util.debug("extracting TRILL embeddings, this might take a while...") + self.util.debug( + "extracting TRILL embeddings, this might take a while...") emb_series = pd.Series(index=self.data_df.index, dtype=object) length = len(self.data_df.index) for idx, file in enumerate(tqdm(self.data_df.index.get_level_values(0))): emb = self.getEmbeddings(file) emb_series[idx] = emb - self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index) + self.df = pd.DataFrame( + emb_series.values.tolist(), index=self.data_df.index) self.df.to_pickle(storage) try: glob_conf.config["DATA"]["needs_feature_extraction"] = "false" diff --git a/nkululeko/feat_extract/feats_wav2vec2.py b/nkululeko/feat_extract/feats_wav2vec2.py index d4fa00b8..6ace2f73 100644 --- a/nkululeko/feat_extract/feats_wav2vec2.py +++ b/nkululeko/feat_extract/feats_wav2vec2.py @@ -1,5 +1,11 @@ -# feats_wav2vec2.py -# feat_types example = wav2vec2-large-robust-ft-swbd-300h +""" feats_wav2vec2.py +feat_types example = [wav2vec2-large-robust-ft-swbd-300h, +wav2vec2-xls-r-2b, wav2vec2-large, wav2vec2-large-xlsr-53, wav2vec2-base] + +Complete list: https://huggingface.co/facebook?search_models=wav2vec2 +Currently only supports wav2vec2 +""" + import os from tqdm import tqdm import pandas as pd @@ -16,11 +22,11 @@ class Wav2vec2(Featureset): def __init__(self, name, data_df, feat_type): """Constructor. is_train is needed to distinguish from test/dev sets, because they use the codebook from the training""" - super().__init__(name, data_df) + super().__init__(name, data_df, feat_type) cuda = "cuda" if torch.cuda.is_available() else "cpu" self.device = self.util.config_val("MODEL", "device", cuda) self.model_initialized = False - if feat_type == "wav2vec" or feat_type == "wav2vec2": + if feat_type == "wav2vec2": self.feat_type = "wav2vec2-large-robust-ft-swbd-300h" else: self.feat_type = feat_type @@ -33,7 +39,8 @@ def init_model(self): ) config = transformers.AutoConfig.from_pretrained(model_path) layer_num = config.num_hidden_layers - hidden_layer = int(self.util.config_val("FEATS", "wav2vec2.layer", "0")) + hidden_layer = int(self.util.config_val( + "FEATS", "wav2vec2.layer", "0")) config.num_hidden_layers = layer_num - hidden_layer self.util.debug(f"using hidden layer #{config.num_hidden_layers}") self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_path) @@ -48,7 +55,8 @@ def extract(self): """Extract the features or load them from disk if present.""" store = self.util.get_path("store") storage = f"{store}{self.name}.pkl" - extract = self.util.config_val("FEATS", "needs_feature_extraction", False) + extract = self.util.config_val( + "FEATS", "needs_feature_extraction", False) no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False")) if extract or no_reuse or not os.path.isfile(storage): if not self.model_initialized: @@ -69,7 +77,8 @@ def extract(self): emb = self.get_embeddings(signal, sampling_rate, file) emb_series[idx] = emb # print(f"emb_series shape: {emb_series.shape}") - self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index) + self.df = pd.DataFrame( + emb_series.values.tolist(), index=self.data_df.index) # print(f"df shape: {self.df.shape}") self.df.to_pickle(storage) try: diff --git a/nkululeko/feat_extract/featureset.py b/nkululeko/feat_extract/featureset.py index 4dfda6d6..810761cb 100644 --- a/nkululeko/feat_extract/featureset.py +++ b/nkululeko/feat_extract/featureset.py @@ -7,13 +7,15 @@ class Featureset: name = "" # designation - df = None # pandas dataframe to store the features (and indexed with the data from the sets) + df = None # pandas dataframe to store the features + # (and indexed with the data from the sets) data_df = None # dataframe to get audio paths - def __init__(self, name, data_df): + def __init__(self, name, data_df, feats_type): self.name = name self.data_df = data_df self.util = Util("featureset") + self.feats_types = feats_type def extract(self): pass @@ -23,7 +25,8 @@ def filter(self): self.df = self.df[self.df.index.isin(self.data_df.index)] try: # use only some features - selected_features = ast.literal_eval(glob_conf.config["FEATS"]["features"]) + selected_features = ast.literal_eval( + glob_conf.config["FEATS"]["features"]) self.util.debug(f"selecting features: {selected_features}") sel_feats_df = pd.DataFrame() hit = False diff --git a/nkululeko/feature_extractor.py b/nkululeko/feature_extractor.py index 468dc5ea..9cc573c6 100644 --- a/nkululeko/feature_extractor.py +++ b/nkululeko/feature_extractor.py @@ -65,16 +65,20 @@ def _get_feat_extractor_class(self, feats_type): elif feats_type == "spectra": from nkululeko.feat_extract.feats_spectra import Spectraloader - return Spectraloader + elif feats_type == "trill": from nkululeko.feat_extract.feats_trill import TRILLset - return TRILLset + elif feats_type.startswith( - ("wav2vec", "hubert", "wavlm", "spkrec", "whisper")): + ("wav2vec2", "hubert", "wavlm", "spkrec", "whisper")): return self._get_feat_extractor_by_prefix(feats_type) + elif feats_type == "xbow": + from nkululeko.feat_extract.feats_oxbow import Openxbow + return Openxbow + elif feats_type in ( "audmodel", "auddim",