Skip to content

Commit

Permalink
Merge pull request felixbur#123 from bagustris/master
Browse files Browse the repository at this point in the history
Make base model for finetuning as variable for INI file
  • Loading branch information
felixbur authored May 28, 2024
2 parents 16c2e59 + b395870 commit 8714bce
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 22 deletions.
6 changes: 4 additions & 2 deletions ini_file.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,12 @@
* batch_size = 8
* **num_workers**: Number of parallel processes for neural nets
* num_workers = 5
* **device**: For torch/huggingface models: select your GPU if you have one
* device = cpu
* **device**: For torch/huggingface models: select your GPU number if you have one. Valid values are either "cpu" or one or more GPU ids (e.g., 0, 1, or both as "0,1"). By default, the GPU (CUDA) is used if available; otherwise the CPU is used.
* device = 0
* **patience**: Number of epochs to wait for the result to improve before training is stopped (for early stopping)
* patience = 5
* **pretrained_model**: Base model for finetuning/transfer learning. Variants of wav2vec2, HuBERT, and WavLM have been tested and work. The default is facebook/wav2vec2-large-robust-ft-swbd-300h.
* pretrained_model = microsoft/wavlm-base

### EXPL
* **model**: Which model to use to estimate feature importance.
Expand Down
52 changes: 33 additions & 19 deletions nkululeko/models/model_tuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import pickle
import typing

import audeer
import audiofile
import audmetric
import datasets
import numpy as np
import pandas as pd
import torch
import transformers
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel

import audeer
import audiofile
import audmetric
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)

import nkululeko.glob_conf as glob_conf
from nkululeko.models.model import Model as BaseModel
Expand All @@ -37,25 +38,32 @@ def __init__(self, df_train, df_test, feats_train, feats_test):
self.target = glob_conf.config["DATA"]["target"]
labels = glob_conf.labels
self.class_num = len(labels)
device = self.util.config_val("MODEL", "device", "cpu")
# device = self.util.config_val("MODEL", "device", "cpu")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
if device != "cpu":
self.util.debug(f"running on device {device}")
# self.device_id = self.util.config_val("MODEL", "device_id", "0")
if self.device != "cpu":
self.util.debug(f"running on device {self.device}")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = device
os.environ["CUDA_VISIBLE_DEVICES"] = self.device # self.device
self.df_train, self.df_test = df_train, df_test
self.epoch_num = int(self.util.config_val("EXP", "epochs", 1))

self._init_model()

def _init_model(self):
model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
pretrained_model = self.util.config_val(
"MODEL", "pretrained_model", model_path)
self.num_layers = None
self.sampling_rate = 16000
self.max_duration_sec = 8.0
self.accumulation_steps = 4
# create dataset

# print finetuning information via debug
self.util.debug(f"Finetuning from model: {pretrained_model}")

# create dataset
dataset = {}
target_name = glob_conf.target
data_sources = {
Expand All @@ -79,15 +87,18 @@ def _init_model(self):
le = glob_conf.label_encoder
mapping = dict(zip(le.classes_, range(len(le.classes_))))
target_mapping = {k: int(v) for k, v in mapping.items()}
target_mapping_reverse = {value: key for key, value in target_mapping.items()}
target_mapping_reverse = {
value: key for key,
value in target_mapping.items()}

self.config = transformers.AutoConfig.from_pretrained(
model_path,
pretrained_model,
num_labels=len(target_mapping),
label2id=target_mapping,
id2label=target_mapping_reverse,
finetuning_task=target_name,
)

if self.num_layers is not None:
self.config.num_hidden_layers = self.num_layers
setattr(self.config, "sampling_rate", self.sampling_rate)
Expand All @@ -113,7 +124,7 @@ def _init_model(self):
assert self.processor.feature_extractor.sampling_rate == self.sampling_rate

self.model = Model.from_pretrained(
model_path,
pretrained_model,
config=self.config,
)
self.model.freeze_feature_extractor()
Expand Down Expand Up @@ -207,19 +218,17 @@ def train(self):
train_weights /= train_weights.sum()
self.util.debug("train weights: {train_weights}")
criterion = torch.nn.CrossEntropyLoss(
weight=torch.Tensor(train_weights).to("cuda"),
weight=torch.Tensor(train_weights).to(self.device),
)
# criterion = torch.nn.CrossEntropyLoss()

class Trainer(transformers.Trainer):

def compute_loss(
self,
model,
inputs,
return_outputs=False,
):

targets = inputs.pop("labels").squeeze()
targets = targets.type(torch.long)

Expand All @@ -246,7 +255,7 @@ def compute_loss(
gradient_accumulation_steps=self.accumulation_steps,
evaluation_strategy="steps",
num_train_epochs=self.epoch_num,
fp16=True,
fp16=self.device == "cuda",
save_steps=num_steps,
eval_steps=num_steps,
logging_steps=num_steps,
Expand Down Expand Up @@ -368,6 +377,9 @@ class Model(Wav2Vec2PreTrainedModel):

def __init__(self, config):

if not hasattr(config, 'add_adapter'):
setattr(config, 'add_adapter', False)

super().__init__(config)

self.wav2vec2 = Wav2Vec2Model(config)
Expand Down Expand Up @@ -462,7 +474,9 @@ def forward(
mean = input_values.mean()

# var = input_values.var()
# raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for the node ReduceProd_3:ReduceProd(11)
# raises: onnxruntime.capi.onnxruntime_pybind11_state.NotImplemented:
# [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an
# implementation for the node ReduceProd_3:ReduceProd(11)

var = torch.square(input_values - mean).mean()
input_values = (input_values - mean) / torch.sqrt(var + 1e-7)
Expand Down
3 changes: 2 additions & 1 deletion tests/exp_emodb_finetune.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[EXP]
root = ./tests/results/
name = test_pretrain
name = test_pretrain_1
runs = 1
epochs = 10
save = True
Expand All @@ -18,3 +18,4 @@ type = []
type = finetune
device = 1
batch_size = 8
pretrained_model = microsoft/wavlm-base

0 comments on commit 8714bce

Please sign in to comment.