Add docstrings #27

Merged
merged 1 commit on Sep 17, 2024
25 changes: 25 additions & 0 deletions utmosv2/_core/create.py
@@ -17,6 +17,31 @@ def create_model(
checkpoint_path: Path | str | None = None,
seed: int = 42,
):
"""
Create a UTMOSv2 model with the specified configuration and optional pretrained weights.

Args:
pretrained (bool):
If True, loads pretrained weights. Defaults to True.
config (str):
The configuration name to load for the model. Defaults to "fusion_stage3".
fold (int):
The fold number for the pretrained weights (used for model selection). Defaults to 0.
checkpoint_path (Path | str | None):
Path to a specific model checkpoint. If None, the checkpoint downloaded from GitHub is used. Defaults to None.
seed (int):
The seed used for model training to select the correct checkpoint. Defaults to 42.

Returns:
UTMOSv2Model: The initialized UTMOSv2 model.

Raises:
FileNotFoundError: If the specified checkpoint file is not found.

Notes:
- The configuration is dynamically loaded from `utmosv2.config`.
- If `pretrained` is True and `checkpoint_path` is not provided, the function attempts to download pretrained weights from GitHub.
"""
_cfg = importlib.import_module(f"utmosv2.config.{config}")
# Avoid issues with pickling `types.ModuleType`,
# making it easier to use with multiprocessing, DDP, etc.
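To put the new docstring in context, here is a minimal usage sketch. It assumes `create_model` is re-exported at the package top level as `utmosv2.create_model` and that the pretrained weights can be downloaded:

```python
import utmosv2

# Build the default "fusion_stage3" model and load pretrained weights
# (downloaded from GitHub because no checkpoint_path is given).
model = utmosv2.create_model(pretrained=True, fold=0, seed=42)
```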
38 changes: 38 additions & 0 deletions utmosv2/_core/model/_common.py
@@ -17,6 +17,10 @@


class UTMOSv2ModelMixin(abc.ABC):
"""
Abstract mixin for UTMOSv2 models, providing a template for prediction.
"""

@property
@abc.abstractmethod
def _cfg(self) -> SimpleNamespace:
@@ -44,6 +48,40 @@ def predict(
num_repetitions: int = 1,
verbose: bool = True,
) -> float | list[dict[str, str | float]]:
"""
Predict the MOS (Mean Opinion Score) of audio files.

Args:
input_path (Path | str | None):
Path to a single audio file (`.wav`) to predict MOS.
Either `input_path` or `input_dir` must be provided, but not both.
input_dir (Path | str | None):
Path to a directory of `.wav` files to predict MOS.
Either `input_path` or `input_dir` must be provided, but not both.
val_list (list[str] | None):
List of filenames to include for prediction. Defaults to None.
val_list_path (Path | str | None):
Path to a text file containing a list of filenames to include for prediction. Defaults to None.
predict_dataset (str):
Name of the dataset to associate with the prediction. Defaults to "sarulab".
device (str | torch.device):
Device to use for prediction (e.g., "cuda:0" or "cpu"). Defaults to "cuda:0".
num_workers (int):
Number of workers for data loading. Defaults to 4.
batch_size (int):
Batch size for the data loader. Defaults to 16.
num_repetitions (int):
Number of prediction repetitions to average results. Defaults to 1.
verbose (bool):
Whether to display progress during prediction. Defaults to True.

Returns:
float: If `input_path` is specified, returns the predicted MOS.
list[dict[str, str | float]]: If `input_dir` is specified, returns a list of dicts containing file paths and predicted MOS scores.

Raises:
ValueError: If both `input_path` and `input_dir` are provided, or if neither is provided.
"""
if not ((input_path is None) ^ (input_dir is None)):
raise ValueError(
"Either `input_path` or `input_dir` must be provided, but not both."
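Continuing the sketch above, the `predict` interface documented here could be exercised like this (the file and directory names are placeholders):

```python
# Single file: returns a float MOS value.
mos = model.predict(input_path="sample.wav", device="cpu")
print(f"MOS: {mos:.3f}")

# Whole directory: returns a list of dicts with file paths and MOS scores.
results = model.predict(input_dir="wav_dir", batch_size=16, num_repetitions=1)
for item in results:
    print(item)
```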
14 changes: 14 additions & 0 deletions utmosv2/_core/model/_models.py
@@ -11,7 +11,21 @@


class UTMOSv2Model(UTMOSv2ModelMixin):
"""
UTMOSv2Model class that wraps different models specified by the configuration.
This class allows for flexible model selection and provides a unified interface for evaluation, calling, and prediction.
"""

def __init__(self, cfg: SimpleNamespace | ModuleType):
"""
Initialize the UTMOSv2Model with a specified configuration.

Args:
cfg (SimpleNamespace | ModuleType): Configuration object that contains the model configuration.

Raises:
ValueError: If the model name specified in the configuration is not recognized.
"""
models = {
"multi_spec_ext": MultiSpecExtModel,
"multi_specv2": MultiSpecModelV2,
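For illustration only, the wrapper can also be built directly from a configuration module, mirroring what `create_model` does internally; note that `utmosv2._core.model._models` is an internal path, so this is a sketch rather than a recommended entry point:

```python
import importlib

from utmosv2._core.model._models import UTMOSv2Model

# Load a configuration module by name, exactly as create_model does.
cfg = importlib.import_module("utmosv2.config.fusion_stage3")
model = UTMOSv2Model(cfg)  # raises ValueError for an unrecognized model name
```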
45 changes: 45 additions & 0 deletions utmosv2/dataset/multi_spec.py
@@ -22,7 +22,28 @@


class MultiSpecDataset(BaseDataset):
"""
Dataset class for mel-spectrogram feature extractor. This class is responsible for
loading audio data, generating multiple spectrograms for each sample, and
applying the necessary transformations.

Args:
cfg (SimpleNamespace): The configuration object containing dataset and model settings.
data (list[DatasetSchema] | pd.DataFrame): The dataset containing file paths and labels.
phase (str): The phase of the dataset, either "train" or any other phase (e.g., "valid").
transform (Callable[[torch.Tensor], torch.Tensor] | None): Transformation function to apply to spectrograms.
"""

def __getitem__(self, idx):
"""
Get the spectrogram and target MOS for a given index.

Args:
idx (int): Index of the sample.

Returns:
tuple: The spectrogram (torch.Tensor) and target MOS (torch.Tensor) for the sample.
"""
row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]
file = row.file_path
y = load_audio(self.cfg, file)
@@ -56,6 +77,20 @@ def __getitem__(self, idx):


class MultiSpecExtDataset(MultiSpecDataset):
"""
Dataset class for mel-spectrogram feature extractor with data-domain embedding.

Args:
cfg (SimpleNamespace | ModuleType):
The configuration object containing dataset and model settings.
data (pd.DataFrame | list[DatasetSchema]):
The dataset containing file paths and labels.
phase (str):
The phase of the dataset, either "train" or any other phase (e.g., "valid").
transform (Callable[[torch.Tensor], torch.Tensor] | None):
Transformation function to apply to spectrograms.
"""

def __init__(
self,
cfg,
@@ -67,6 +102,16 @@ def __init__(
self.dataset_map = get_dataset_map(cfg)

def __getitem__(self, idx):
"""
Get the spectrogram, data-domain embedding, and target MOS for a given index.

Args:
idx (int): Index of the sample.

Returns:
tuple: A tuple containing the generated spectrogram (torch.Tensor), data-domain embedding (torch.Tensor),
and target MOS (torch.Tensor).
"""
spec, target = super().__getitem__(idx)
row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]

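A hedged sketch of how these dataset classes might be consumed; the `cfg` object and the `train_data` frame are assumed to come from the training pipeline and are not defined here:

```python
from torch.utils.data import DataLoader

from utmosv2.dataset.multi_spec import MultiSpecExtDataset

# `cfg` and `train_data` are assumed to be supplied by the training pipeline.
dataset = MultiSpecExtDataset(cfg, train_data, phase="train", transform=None)
spec, domain, target = dataset[0]  # spectrogram, data-domain embedding, MOS

loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
```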
42 changes: 42 additions & 0 deletions utmosv2/dataset/ssl.py
@@ -20,7 +20,28 @@


class SSLDataset(BaseDataset):
"""
Dataset class for SSL (Self-Supervised Learning) feature extractor.
This class handles audio loading, extending, and random selection of a segment from the audio.

Args:
cfg (SimpleNamespace | ModuleType):
The configuration object containing dataset and model settings.
data (pd.DataFrame | list[DatasetSchema]):
The dataset containing file paths and MOS labels.
phase (str):
The phase of the dataset, either "train" or any other phase (e.g., "valid").
"""

def __getitem__(self, idx):
"""
Get the processed audio and the target MOS for a given index.

Args:
idx (int): Index of the sample.

Returns:
tuple: A tuple containing the processed audio (torch.Tensor) and the target MOS (torch.Tensor).
"""
row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]
file = row.file_path
y = load_audio(self.cfg, file)
@@ -35,11 +56,32 @@ def __getitem__(self, idx):


class SSLExtDataset(SSLDataset):
"""
Dataset class for SSL (Self-Supervised Learning) feature extractor with data-domain embedding.

Args:
cfg (SimpleNamespace | ModuleType):
The configuration object containing dataset and model settings.
data (pd.DataFrame | list[DatasetSchema]):
The dataset containing file paths and MOS labels.
phase (str):
The phase of the dataset, either "train" or any other phase (e.g., "valid").
"""

def __init__(self, cfg, data: "pd.DataFrame" | list[DatasetSchema], phase: str):
super().__init__(cfg, data, phase)
self.dataset_map = get_dataset_map(cfg)

def __getitem__(self, idx):
"""
Get the processed audio, data-domain embedding, and target MOS for a given index.

Args:
idx (int): Index of the sample.

Returns:
tuple: A tuple containing the processed audio (torch.Tensor), data-domain embedding (torch.Tensor),
and target MOS (torch.Tensor).
"""
y, target = super().__getitem__(idx)
row = self.data[idx] if isinstance(self.data, list) else self.data.iloc[idx]

26 changes: 26 additions & 0 deletions utmosv2/dataset/ssl_multispec.py
@@ -15,6 +15,22 @@


class SSLLMultiSpecExtDataset(BaseDataset):
"""
Dataset class that combines both SSL (Self-Supervised Learning) and Multi-Spectrogram datasets.
This dataset uses both SSLExtDataset and MultiSpecDataset to provide different representations
of the same audio sample.

Args:
cfg (SimpleNamespace | ModuleType):
The configuration object containing dataset and model settings.
data (pd.DataFrame | list[DatasetSchema]):
The dataset containing file paths and MOS labels.
phase (str):
The phase of the dataset, either "train" or any other phase (e.g., "valid").
transform (Callable[[torch.Tensor], torch.Tensor] | None):
Transformation function to apply to spectrograms.
"""

def __init__(
self,
cfg,
@@ -30,6 +46,16 @@ def __len__(self):
return len(self.data)

def __getitem__(self, idx):
"""
Get the SSL feature-extractor input, the mel-spectrogram feature-extractor input, the data-domain embedding, and the target MOS for a given index.

Args:
idx (int): Index of the sample.

Returns:
tuple: SSL feature-extractor input (torch.Tensor), mel-spectrogram feature-extractor input (torch.Tensor),
data-domain embedding (torch.Tensor), and target MOS (torch.Tensor).
"""
x1, d, target = self.ssl[idx]
x2, _ = self.multi_spec[idx]

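A short sketch of unpacking one item from this combined dataset, again assuming `cfg` and `data` are provided by the surrounding pipeline:

```python
from utmosv2.dataset.ssl_multispec import SSLLMultiSpecExtDataset

# Each item bundles both representations of the same clip plus its label.
dataset = SSLLMultiSpecExtDataset(cfg, data, phase="valid", transform=None)
x_ssl, x_spec, domain, target = dataset[0]
```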
44 changes: 44 additions & 0 deletions utmosv2/loss/_losses.py
@@ -6,12 +6,37 @@


class PairwizeDiffLoss(nn.Module):
"""
Pairwise difference loss function for comparing input and target tensors.
The loss is based on the difference between pairs of inputs and pairs of targets,
with a specified margin and norm ("l1" or "l2_squared").
"""

def __init__(self, margin: float = 0.2, norm: str = "l1"):
"""
Initialize the PairwizeDiffLoss with the specified margin and norm.

Args:
margin (float):
The margin value used for the loss function. Defaults to 0.2.
norm (str):
The norm to use for the difference calculation. Must be "l1" or "l2_squared". Defaults to "l1".
"""
super().__init__()
self.margin = margin
self.norm = norm

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute the pairwise difference loss between input and target tensors.

Args:
input (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.

Returns:
torch.Tensor: The computed loss.
"""
s = input.unsqueeze(1) - input.unsqueeze(0)
t = target.unsqueeze(1) - target.unsqueeze(0)
if self.norm not in ["l1", "l2_squared"]:
@@ -27,11 +52,30 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


class CombinedLoss(nn.Module):
"""
A combined loss function that allows for multiple loss functions to be weighted and combined.

Args:
weighted_losses (list[tuple[nn.Module, float]]):
A list of loss functions and their associated weights.
"""

def __init__(self, weighted_losses: list[tuple[nn.Module, float]]):
super().__init__()
self.weighted_losses = weighted_losses

def forward(
self, input: torch.Tensor, target: torch.Tensor
) -> list[tuple[float, torch.Tensor]]:
"""
Compute the weighted loss for each loss function in the list.

Args:
input (torch.Tensor): The input tensor.
target (torch.Tensor): The target tensor.

Returns:
list[tuple[float, torch.Tensor]]:
A list of tuples where each contains a weight and the corresponding computed loss.
"""
return [(w, loss(input, target)) for loss, w in self.weighted_losses]
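Finally, a sketch of how the two loss classes compose; the weights and the added `nn.MSELoss` term are illustrative, not the project's actual training configuration:

```python
import torch
import torch.nn as nn

from utmosv2.loss._losses import CombinedLoss, PairwizeDiffLoss

criterion = CombinedLoss(
    [(PairwizeDiffLoss(margin=0.2, norm="l1"), 0.5), (nn.MSELoss(), 0.5)]
)

pred = torch.randn(8, requires_grad=True)  # predicted MOS for a batch of 8 clips
mos = torch.rand(8) * 4.0 + 1.0            # ground-truth MOS in [1, 5]

# CombinedLoss yields (weight, loss) pairs; sum them into a single scalar.
total = sum(w * loss for w, loss in criterion(pred, mos))
total.backward()
```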