Skip to content

Commit

Permalink
add device_id param
Browse files Browse the repository at this point in the history
  • Loading branch information
bagustris committed May 27, 2024
1 parent 7fc73d4 commit 9c0a260
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
4 changes: 3 additions & 1 deletion ini_file.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,10 @@
* batch_size = 8
* **num_workers**: Number of parallel processes for neural nets
* num_workers = 5
* **device**: For torch/huggingface models: select your GPU if you have one
* **device**: For torch/huggingface models: select your GPU if you have one. Values are either "cpu" or "cuda".
* device = cpu
* **device_ids**: For torch/huggingface models: select your GPU if you have multiple. Values are GPU ids (0, 1 or both "0,1").
* device_id = 0
* **patience**: Number of epochs to wait if the result gets better (for early stopping)
* patience = 5
* **pretrained_model**: Base model for finetuning/transfer learning. Variants of wav2vec2, Hubert, and WavLM are tested to work.
Expand Down
3 changes: 2 additions & 1 deletion nkululeko/models/model_tuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ def __init__(self, df_train, df_test, feats_train, feats_test):
# device = self.util.config_val("MODEL", "device", "cpu")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.batch_size = int(self.util.config_val("MODEL", "batch_size", "8"))
self.device_id = self.util.config_val("MODEL", "device_id", "0")
if self.device != "cpu":
self.util.debug(f"running on device {self.device}")
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # self.device
os.environ["CUDA_VISIBLE_DEVICES"] = self.device_id # self.device
self.df_train, self.df_test = df_train, df_test
self.epoch_num = int(self.util.config_val("EXP", "epochs", 1))

Expand Down
2 changes: 1 addition & 1 deletion tests/exp_emodb_finetune.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[EXP]
root = ./tests/results/
name = test_pretrain
name = test_pretrain_1
runs = 1
epochs = 10
save = True
Expand Down

0 comments on commit 9c0a260

Please sign in to comment.