📘 Guide to Inference 📘

Please refer to this section for basic inference methods.

If you want to use the inference.py script rather than the utmosv2 library, install the additional dependencies first:

pip install --upgrade pip  # enable PEP 660 support
pip install -e .[optional]

Note

If you are using zsh, make sure to escape the square brackets like this:

pip install -e '.[optional]'

📌 Data-domain ID for the MOS Prediction 📌

By default, the data-domain ID for MOS prediction is sarulab (the sarulab-data domain). To predict with a different domain, pass the --predict_dataset flag with one of the following options:

  • sarulab (default)
  • bvcc
  • blizzard2008, blizzard2009, blizzard2010-EH1, blizzard2010-EH2, blizzard2010-ES1, blizzard2010-ES3, blizzard2011
  • somos

For example, to make predictions with the data-domain ID set to somos, use the following command:

  • If you are using the library in your Python code:

    mos = model.predict(input_dir="/path/to/wav/dir/", predict_dataset="somos")
  • If you are using the inference script:

    python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --predict_dataset somos
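
Because a data-domain ID is easy to mistype, it can help to check the value before passing it on. The sketch below is illustrative only: `VALID_DOMAINS` and `check_domain` are hypothetical names, not part of utmosv2, and the set simply mirrors the options listed above.

```python
# Hypothetical helper, not part of utmosv2.
# The set mirrors the data-domain IDs listed in this guide.
VALID_DOMAINS = {
    "sarulab", "bvcc", "somos",
    "blizzard2008", "blizzard2009",
    "blizzard2010-EH1", "blizzard2010-EH2",
    "blizzard2010-ES1", "blizzard2010-ES3",
    "blizzard2011",
}


def check_domain(name: str) -> str:
    """Return name unchanged if it is a listed data-domain ID, else raise ValueError."""
    if name not in VALID_DOMAINS:
        raise ValueError(
            f"unknown data-domain ID {name!r}; expected one of {sorted(VALID_DOMAINS)}"
        )
    return name
```

The checked value can then be passed as `predict_dataset=check_domain("somos")` in the `model.predict` call shown above.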

✂️ Predicting Only a Subset of Files ✂️

By default, all .wav files in the --input_dir are used for prediction. To specify only a subset of these files, use the --val_list_path flag:

  • If you are using the library in your Python code:

    mos = model.predict(input_dir="/path/to/wav/dir/", val_list_path="/path/to/your/val/list.txt")

    or, you can provide the list directly:

    mos = model.predict(
        input_dir="/path/to/wav/dir/",
        val_list=["sys00691-utt0682e32", "sys00691-utt31fd854", "sys00691-utt33a4826", ...]
    )
  • If you are using the inference script:

    python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --val_list_path /path/to/your/val/list.txt

The list file should contain one utt-id per line, as shown below. The .wav file extension is optional and can be included or omitted.

sys00691-utt0682e32
sys00691-utt31fd854
sys00691-utt33a4826
...
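
If the list already lives in a text file but you want to pass it through the `val_list` argument instead, a small reader is enough. This is a minimal sketch, assuming the one-utt-id-per-line format described above. `normalize_utt_id` and `read_val_list` are hypothetical helpers, not utmosv2 functions, and stripping the `.wav` suffix is just one way to make entries uniform.

```python
from pathlib import Path


def normalize_utt_id(entry: str) -> str:
    """Strip surrounding whitespace and an optional trailing .wav extension."""
    entry = entry.strip()
    return entry[:-4] if entry.endswith(".wav") else entry


def read_val_list(path) -> list[str]:
    """Read one utt-id per line from path, skipping blank lines."""
    lines = Path(path).read_text(encoding="utf-8").splitlines()
    return [normalize_utt_id(line) for line in lines if line.strip()]
```

The returned list can be passed as `val_list=read_val_list("/path/to/your/val/list.txt")`.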

📈 Specify the Fold and the Number of Repetitions for More Accurate Predictions 📈

In the paper, each of the five folds makes predictions on five randomly selected frames of the input speech waveform, and the results are averaged. To control the fold and the number of repetitions for more accurate predictions, do the following:

  • If you are using the library in your Python code:

    model = utmosv2.create_model(fold=2)
    mos = model.predict(input_dir="/path/to/wav/dir/", num_repetitions=5)
  • If you are using the inference script:

    python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --fold 2 --num_repetitions 5

Here, the --fold option specifies the fold number to be used. If set to -1, all folds will be used. The --num_repetitions option specifies the number of repetitions.
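
To make the averaging concrete, the toy sketch below averages a score over every fold and every repetition. It only illustrates the arithmetic described above and is not how utmosv2 is implemented. `average_over_folds` and `score_fn` are hypothetical names, and `score_fn` stands in for one model pass on one randomly chosen frame.

```python
import random
from statistics import mean


def average_over_folds(score_fn, num_folds=5, num_repetitions=5, seed=0):
    """Average score_fn(fold, rng) over every fold and every repetition.

    score_fn is a stand-in for a single prediction on one random frame;
    it is not a utmosv2 function.
    """
    rng = random.Random(seed)  # shared RNG so a run is reproducible
    scores = [
        score_fn(fold, rng)
        for fold in range(num_folds)
        for _ in range(num_repetitions)
    ]
    return mean(scores)
```

With five folds and five repetitions, the final score is the mean of 25 individual predictions, which is why the result is steadier than a single pass.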

🎯 Specify a Configuration File 🎯

To specify a configuration file for predictions, do the following:

  • If you are using the library in your Python code:

    model = utmosv2.create_model(config="configuration_file_name")
    mos = model.predict(input_dir="/path/to/wav/dir/")
  • If you are using the inference script:

    python inference.py --config configuration_file_name --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv

By default, fusion_stage3, which is the entire model of UTMOSv2, is used.

⚖️ Make Predictions Using Your Own Weights ⚖️

If you are using the library in your Python code, specify the checkpoint path with the checkpoint_path argument to make predictions using your own weights:

model = utmosv2.create_model(checkpoint_path="/path/to/your/weight.pth")
mos = model.predict(input_dir="/path/to/wav/dir/")

If you are using the inference script, specify the path to the weights with the --weight option to make predictions using your own weights:

python inference.py --input_dir /path/to/wav/dir/ --out_path /path/to/output/file.csv --weight /path/to/your/weight.pth

The checkpoint_path argument and --weight option can specify either the configuration file name or the path to the weight .pth file. By default, models/{config_name}/fold{now_fold}_s{seed}_best_model.pth is used.

The weights must be compatible with the model specified by the config argument or the --config_name option.

Note

In this case, the specified weights are used for all folds. To use different weights for each fold, run the script once per fold:

for i in {0..4}; do
    python inference.py --input_path /path/to/wav/file.wav --out_path /path/to/output/file.csv --weight /path/to/your/weight_fold${i}.pth --fold $i
done
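
The same per-fold loop can be driven from Python. The sketch below only builds the command lines, using the flags shown in this guide. `fold_commands` is a hypothetical helper. Two assumptions are made here: the folds are numbered 0 to 4 (five folds, as in the paper), and each fold writes to its own output file so results from different folds are kept apart.

```python
import sys


def fold_commands(input_path, weight_template, out_template, num_folds=5):
    """Build one inference.py command line per fold; nothing is executed here.

    weight_template and out_template must contain a {fold} placeholder.
    Assumption: folds are numbered 0 .. num_folds - 1.
    """
    return [
        [
            sys.executable, "inference.py",
            "--input_path", input_path,
            "--out_path", out_template.format(fold=fold),
            "--weight", weight_template.format(fold=fold),
            "--fold", str(fold),
        ]
        for fold in range(num_folds)
    ]
```

Each command can then be run from the repository root with `subprocess.run(cmd, check=True)`, which stops at the first fold that fails.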