From 239eef1dcb74c949acc1004955a4257a417feecf Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:10:03 -0600 Subject: [PATCH 01/18] Replace setup.py with pyproject.toml and prune requirements --- pyproject.toml | 43 +++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 7 +++---- setup.py | 28 ---------------------------- 3 files changed, 46 insertions(+), 32 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c914364 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = ["hatchling", "hatch-requirements-txt"] +build-backend = "hatchling.build" + +[project] +name = "timbre" +version = "0" +# dependencies = [] +requires-python = ">=3.8" +authors = [{ name = "Quinn Ouyang", email = "qouyang3+timbre@illinois.edu" }] +# maintainers = [ +# { name = "Quinn qouyang", email = "qouyang3+timbre@illinois.edu" }, +# ] +# description = "" +readme = "README.md" +# license = {file = "LICENSE"} +# keywords = [] +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + # "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.8", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Multimedia :: Sound/Audio :: Sound Synthesis", +] +dynamic = ["dependencies"] + +[project.urls] +Homepage = "https://github.com/quinnouyang/timbre" +# Documentation = "" +Repository = "https://github.com/quinnouyang/timbre.git" +"Bug Tracker" = "https://github.com/quinnouyang/timbre/issues" +# Changelog = "" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.metadata.hooks.requirements_txt] +files = ["requirements.txt"] + +# [tool.hatch.metadata.hooks.requirements_txt.optional-dependencies] +# cli = ["requirements-dev.txt"] diff --git a/requirements.txt b/requirements.txt index 267d60d..1be396d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,8 @@ black ipykernel +librosa matplotlib -scipy -tensorboard -torch +pip torchaudio torchvision -tqdm +tqdm \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 5051897..0000000 --- a/setup.py +++ /dev/null @@ -1,28 +0,0 @@ -from os.path import dirname, join -from pkg_resources import parse_requirements -from setuptools import setup, find_packages - -with open("README.md") as file: - long_description = file.read() - -setup( - name="timbre", - py_modules=["timbre"], - version="0.0", - description="", - author="Quinn Ouyang", - packages=find_packages(exclude=["tests*"]), - install_requires=[ - str(r) - for r in parse_requirements(open(join(dirname(__file__), "requirements.txt"))) - ], - include_package_data=True, - author_email="qouyang3@illinois.edu", - url="https://github.com/quinnouyang/timbre", - package_data={"timbre": ["assets/*", "assets/*/*"]}, - long_description=long_description, - long_description_content_type="text/markdown", - keywords=[], - classifiers=["License :: OSI Approved :: MIT License"], - license="MIT", -) From 63931b4950d7dc28c1a719715acca0cb9dcc8fe3 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:15:47 -0600 Subject: [PATCH 02/18] Move data code to package dir --- data/__init__.py | 1 - data/utils/nsynth.py | 117 -------------------------------------- timbre/model/train/run.py | 2 +- 3 files changed, 1 insertion(+), 119 deletions(-) delete mode 100644 data/__init__.py delete mode 100644 data/utils/nsynth.py diff --git a/data/__init__.py b/data/__init__.py deleted file mode 100644 index 5f12e2e..0000000 --- a/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils.nsynth import NSynthExample, NSynthDataset diff --git a/data/utils/nsynth.py b/data/utils/nsynth.py deleted file mode 100644 index 7ed6dc1..0000000 --- a/data/utils/nsynth.py +++ /dev/null @@ -1,117 +0,0 @@ -import json - -from argparse import ArgumentParser -from contextlib import suppress -from os import path, listdir -from pathlib import Path -from torch import Tensor, float32 -from torch.utils.data import Dataset -from typing import NamedTuple, Literal, Any -from torchaudio import load, transforms - -transform_melspec = transforms.MelSpectrogram(n_fft=512, n_mels=64) - - -def forward_transform(x: Tensor) -> Tensor: - return transform_melspec(x).flatten().to(float32) - - -NSynthExample = NamedTuple( - "NSynthExample", - [ - ("note", int), - ("note_str", str), - ("instrument", int), - ("instrument_str", str), - ("pitch", int), - ("velocity", int), - ("sample_rate", int), - ("audio", Tensor), - ("qualities", list[Literal[0, 1]]), - # ("qualities_str", list[str]), # Exclude because it raises a RuntimeError from the default DataLoader collate_fn when it varies in length - ("instrument_family", int), - ("instrument_family_str", str), - ("instrument_source", int), - ("instrument_source_str", str), - ], -) -"NSynth example [features](https://magenta.tensorflow.org/datasets/nsynth#example-features)" - - -class NSynthDataset(Dataset): - """A subset of the [NSynth dataset](https://magenta.tensorflow.org/datasets/nsynth) - - Parameters - ---------- - `annotations_file` : `str` | `Path` - Path to the JSON file containing the annotations - `audio_dir` : `str` | `Path` - Path to the directory containing the audio files - """ - - def __init__( - self, - annotations_file: str | Path, - audio_dir: str | Path, - ) -> None: - with open(annotations_file, "r") as f: - self.annotations: dict[str, Any] = json.load(f) - self.keys = sorted(self.annotations.keys()) # note_strs - self.audio_dir = audio_dir - self.audio_filenames = sorted( - path.splitext(path.basename(f))[0] for f in listdir(audio_dir) - ) - - # Verify that audio_dir and annotations_file alphabetically map 1:1 - assert ( - self.keys == self.audio_filenames - ), f"Expected every key/note_str from annotations_file to match every filename from audio_dir\n Instead, audio_dir is missing {set(self.keys) - set(self.audio_filenames)} and/or annotations_file is mising {set(self.audio_filenames) - set(self.keys)}" - - def __len__(self): - return len(self.keys) - - def __getitem__(self, i: int) -> Tensor: - """Load the `i`-th example - - Parameters - ---------- - `i` : `int` - Example index (by alphabetical order of the example filename/"note_str") - - Returns - ------- - Flattened magnitude spectrogram of the example - # NSynth example features : `NSynthExample` - # See the [NSynth example features](https://magenta.tensorflow.org/datasets/nsynth#example-features) - """ - annotation: dict = self.annotations[self.keys[i]] - with suppress(KeyError): # Remove the "qualities_str" key, if not already - annotation.pop("qualities_str") - - return forward_transform( - load(path.join(self.audio_dir, self.audio_filenames[i] + ".wav"))[0][0] - ) - - return NSynthExample( - **annotation, - audio=from_numpy( - read(path.join(self.audio_dir, self.audio_filenames[i] + ".wav"))[1] - ) - / 32768.0, # 2**15 (to normalize 16-bit samples to [-1, 1]) - ) - - -if __name__ == "__main__": - p = ArgumentParser() - p.add_argument("--subset-path", type=str, required=True) - - args = p.parse_args() - - print(f"TEST: Loading NSynthDataset and printing its last example") - - D = NSynthDataset( - path.join(args.subset_path, "examples.json"), - path.join(args.subset_path, "audio"), - ) - - print(D[-1]) diff --git a/timbre/model/train/run.py b/timbre/model/train/run.py index f81846f..9e1f1e2 100644 --- a/timbre/model/train/run.py +++ b/timbre/model/train/run.py @@ -1,7 +1,7 @@ from torch.optim.adamw import AdamW from torch.utils.data import DataLoader -from data.utils.nsynth import NSynthDataset +from timbre.datasets.nsynth import NSynthDataset from timbre.model.config.single import ( DATASETS_DIR, BATCH_SIZE, From 9fe09ed7739f7ba7a59c19966ed29b5486114990 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:46:02 -0600 Subject: [PATCH 03/18] Refactor single config as untested default --- .gitignore | 9 ++- timbre/config/__init__.py | 0 timbre/config/defaults.py | 101 ++++++++++++++++++++++++++++++++++ timbre/model/config/single.py | 33 ----------- 4 files changed, 105 insertions(+), 38 deletions(-) create mode 100644 timbre/config/__init__.py create mode 100644 timbre/config/defaults.py delete mode 100644 timbre/model/config/single.py diff --git a/.gitignore b/.gitignore index 4e4913d..53146c8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,11 +3,10 @@ __pycache__ .python-version -# Datasets -datasets/ - -# Outputs -runs +# Directories +data/ +runs/ +reference_models/heidenreich/datasets/ # Miscellaneous **/.DS_Store diff --git a/timbre/config/__init__.py b/timbre/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py new file mode 100644 index 0000000..108f067 --- /dev/null +++ b/timbre/config/defaults.py @@ -0,0 +1,101 @@ +# Based off of https://github.com/maxrmorrison/deep-learning-project-template/blob/main/NAME/config/defaults.py + +import torch + +from os import cpu_count +from pathlib import Path + +# Configuration name +CONFIG = "timbre" + + +############################################################################### +# Data +############################################################################### + +# Dataset names +DATASETS = ["nsynth"] + +EVALUATION_DATASETS = DATASETS + + +############################################################################### +# Directories +############################################################################### + +PACKAGE_DIR = Path(__file__).parent.parent + +# For assets to bundle with pip release +# ASSETS_DIR = PACKAGE_DIR / "assets" + +ROOT_DIR = PACKAGE_DIR.parent + +# For preprocessed features +# CACHE_DIR = ROOT_DIR.parent / "data" / "cache" + +# For unprocessed datasets +SOURCES_DIR = ROOT_DIR / "data" / "sources" + +# For preprocessed datasets +DATA_DIR = ROOT_DIR / "data" / "datasets" + +# For training and adaptation artifacts +RUNS_DIR = ROOT_DIR / "runs" + +# For evaluation artifacts +# EVAL_DIR = ROOT_DIR.parent / "eval" + + +############################################################################### +# Training +############################################################################### + + +# Batch size per gpu +BATCH_SIZE = 128 + +# Steps between saving checkpoints +CHECKPOINT_INTERVAL = 25000 + +# Training steps +STEPS = 300000 + +n_cpus = cpu_count() or 0 +if not n_cpus: + raise ValueError("Could not determine the number of CPUs") + +# Worker threads for data loading +NUM_WORKERS = int(n_cpus / max(1, torch.cuda.device_count())) + +RANDOM_SEED = 1234 + +# [TODO] Formalize these options + +LEARN_RATE = 1e-3 +WEIGHT_DECAY = 1e-2 +N_EPOCHS = 64 +INPUT_DIM = 16064 +LATENT_DIM = 2 +HIDDEN_DIM = 8064 + +DEVICE = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" +) + +USE_PIN_MEMORY = DEVICE == "cuda" +# DATETIME_NOW = datetime.now().strftime("%Y%m%d-%H%M%S") +# WRITER = SummaryWriter(RUNS_DIR / f"log_{DATETIME_NOW}") + + +############################################################################### +# Evaluation +############################################################################### + + +# Steps between tensorboard logging +# EVALUATION_INTERVAL = 2500 # steps + +# Steps to perform for tensorboard logging +# DEFAULT_EVALUATION_STEPS = 16 diff --git a/timbre/model/config/single.py b/timbre/model/config/single.py deleted file mode 100644 index 53281d9..0000000 --- a/timbre/model/config/single.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch - -from datetime import datetime -from pathlib import Path -from torch.utils.tensorboard.writer import SummaryWriter - - -CONFIG_DIR = Path(__file__).parent -MODEL_DIR = CONFIG_DIR.parent -PROJ_DIR = MODEL_DIR.parent -RUNS_DIR = CONFIG_DIR / "runs" -DATASETS_DIR = PROJ_DIR / "data" / "datasets" - -BATCH_SIZE = 128 -LEARN_RATE = 1e-3 -WEIGHT_DECAY = 1e-2 -N_EPOCHS = 64 -INPUT_DIM = 16064 -LATENT_DIM = 2 -HIDDEN_DIM = 8064 -NUM_WORKERS = 0 # os.cpu_count() or 0 -DEVICE = torch.device( - "cuda" - if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() else "cpu" -) -PIN_MEMORY = DEVICE == "cuda" -DATETIME_NOW = datetime.now().strftime("%Y%m%d-%H%M%S") -WRITER = SummaryWriter(RUNS_DIR / f"log_{DATETIME_NOW}") - -print( - f"SINGLE CONFIGURATION\nDevice: {DEVICE}\nBatch size: {BATCH_SIZE}\nConfiguration directory: {CONFIG_DIR.relative_to(PROJ_DIR)}\nRuns directory: {RUNS_DIR.relative_to(PROJ_DIR)}\nData directory: {DATASETS_DIR.relative_to(PROJ_DIR)}\n" -) From 997896b49c9529462708df8b810f76cd54b0965f Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Wed, 4 Dec 2024 22:06:48 -0600 Subject: [PATCH 04/18] [config] Update to separate utils and not estimate number of workers --- timbre/config/defaults.py | 32 ++++++++++++-------------------- timbre/config/utils.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 20 deletions(-) create mode 100644 timbre/config/utils.py diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py index 108f067..7a72e36 100644 --- a/timbre/config/defaults.py +++ b/timbre/config/defaults.py @@ -1,12 +1,12 @@ # Based off of https://github.com/maxrmorrison/deep-learning-project-template/blob/main/NAME/config/defaults.py -import torch - -from os import cpu_count from pathlib import Path +from .utils import get_device, should_pin_memory + + # Configuration name -CONFIG = "timbre" +CONFIG = "default" ############################################################################### @@ -16,7 +16,7 @@ # Dataset names DATASETS = ["nsynth"] -EVALUATION_DATASETS = DATASETS +# EVALUATION_DATASETS = DATASETS ############################################################################### @@ -55,36 +55,28 @@ BATCH_SIZE = 128 # Steps between saving checkpoints -CHECKPOINT_INTERVAL = 25000 +# CHECKPOINT_INTERVAL = 25000 # Training steps -STEPS = 300000 - -n_cpus = cpu_count() or 0 -if not n_cpus: - raise ValueError("Could not determine the number of CPUs") +# STEPS = 300000 # Worker threads for data loading -NUM_WORKERS = int(n_cpus / max(1, torch.cuda.device_count())) +NUM_WORKERS = 8 -RANDOM_SEED = 1234 +# RANDOM_SEED = 1234 # [TODO] Formalize these options LEARN_RATE = 1e-3 WEIGHT_DECAY = 1e-2 -N_EPOCHS = 64 +NUM_EPOCHS = 64 INPUT_DIM = 16064 LATENT_DIM = 2 HIDDEN_DIM = 8064 -DEVICE = torch.device( - "cuda" - if torch.cuda.is_available() - else "mps" if torch.backends.mps.is_available() else "cpu" -) +DEVICE = get_device() -USE_PIN_MEMORY = DEVICE == "cuda" +USE_PIN_MEMORY = should_pin_memory() # DATETIME_NOW = datetime.now().strftime("%Y%m%d-%H%M%S") # WRITER = SummaryWriter(RUNS_DIR / f"log_{DATETIME_NOW}") diff --git a/timbre/config/utils.py b/timbre/config/utils.py new file mode 100644 index 0000000..e56e961 --- /dev/null +++ b/timbre/config/utils.py @@ -0,0 +1,13 @@ +import torch + + +def get_device() -> torch.device: + return torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + + +def should_pin_memory() -> bool: + return get_device() == "cuda" From 36929949dd0276655f1044ede178d24362b52e74 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Thu, 5 Dec 2024 00:07:05 -0600 Subject: [PATCH 05/18] Rename dataset dirs --- .gitignore | 2 +- timbre/config/defaults.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 53146c8..070f139 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ __pycache__ .python-version # Directories -data/ +datasets/ runs/ reference_models/heidenreich/datasets/ diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py index 7a72e36..52b5d41 100644 --- a/timbre/config/defaults.py +++ b/timbre/config/defaults.py @@ -6,7 +6,7 @@ # Configuration name -CONFIG = "default" +CONFIG = "defaults" ############################################################################### @@ -34,10 +34,10 @@ # CACHE_DIR = ROOT_DIR.parent / "data" / "cache" # For unprocessed datasets -SOURCES_DIR = ROOT_DIR / "data" / "sources" +SOURCES_DIR = ROOT_DIR / "datasets" / "sources" # For preprocessed datasets -DATA_DIR = ROOT_DIR / "data" / "datasets" +PREPROCESSED_DIR = ROOT_DIR / "datasets" / "preprocessed" # For training and adaptation artifacts RUNS_DIR = ROOT_DIR / "runs" From 6c078854addaae72bb0255050c555885ebaaf876 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Thu, 5 Dec 2024 01:16:29 -0600 Subject: [PATCH 06/18] Add init files, use yapecs for config --- timbre/__init__.py | 22 ++++++++++++++++++++-- timbre/data/__init__.py | 0 timbre/model/__init__.py | 0 timbre/train/__init__.py | 0 4 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 timbre/data/__init__.py create mode 100644 timbre/model/__init__.py create mode 100644 timbre/train/__init__.py diff --git a/timbre/__init__.py b/timbre/__init__.py index aa6a0d7..1aabc53 100644 --- a/timbre/__init__.py +++ b/timbre/__init__.py @@ -1,2 +1,20 @@ -from .model.train.utils import train, test, plot -from .model.vae import VAE, VAEOutput +############################################################################### +# Configuration +############################################################################### + + +from .config import defaults + +import yapecs + +yapecs.configure("timbre", defaults) + +from .config.defaults import * + +############################################################################### +# Module +############################################################################### + +# from .train import ... +from . import data +from . import model diff --git a/timbre/data/__init__.py b/timbre/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/timbre/model/__init__.py b/timbre/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/timbre/train/__init__.py b/timbre/train/__init__.py new file mode 100644 index 0000000..e69de29 From fd7fc26b6ffe79a11ceca6f6d906aab3e0d38c0b Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Thu, 5 Dec 2024 01:17:00 -0600 Subject: [PATCH 07/18] Move data and train files --- timbre/data/nsynth.py | 117 ++++++++++++++++++++++++++++++ timbre/{model => }/train/run.py | 19 +---- timbre/{model => }/train/utils.py | 0 3 files changed, 120 insertions(+), 16 deletions(-) create mode 100644 timbre/data/nsynth.py rename timbre/{model => }/train/run.py (84%) rename timbre/{model => }/train/utils.py (100%) diff --git a/timbre/data/nsynth.py b/timbre/data/nsynth.py new file mode 100644 index 0000000..7ed6dc1 --- /dev/null +++ b/timbre/data/nsynth.py @@ -0,0 +1,117 @@ +import json + +from argparse import ArgumentParser +from contextlib import suppress +from os import path, listdir +from pathlib import Path +from torch import Tensor, float32 +from torch.utils.data import Dataset +from typing import NamedTuple, Literal, Any +from torchaudio import load, transforms + +transform_melspec = transforms.MelSpectrogram(n_fft=512, n_mels=64) + + +def forward_transform(x: Tensor) -> Tensor: + return transform_melspec(x).flatten().to(float32) + + +NSynthExample = NamedTuple( + "NSynthExample", + [ + ("note", int), + ("note_str", str), + ("instrument", int), + ("instrument_str", str), + ("pitch", int), + ("velocity", int), + ("sample_rate", int), + ("audio", Tensor), + ("qualities", list[Literal[0, 1]]), + # ("qualities_str", list[str]), # Exclude because it raises a RuntimeError from the default DataLoader collate_fn when it varies in length + ("instrument_family", int), + ("instrument_family_str", str), + ("instrument_source", int), + ("instrument_source_str", str), + ], +) +"NSynth example [features](https://magenta.tensorflow.org/datasets/nsynth#example-features)" + + +class NSynthDataset(Dataset): + """A subset of the [NSynth dataset](https://magenta.tensorflow.org/datasets/nsynth) + + Parameters + ---------- + `annotations_file` : `str` | `Path` + Path to the JSON file containing the annotations + `audio_dir` : `str` | `Path` + Path to the directory containing the audio files + """ + + def __init__( + self, + annotations_file: str | Path, + audio_dir: str | Path, + ) -> None: + with open(annotations_file, "r") as f: + self.annotations: dict[str, Any] = json.load(f) + self.keys = sorted(self.annotations.keys()) # note_strs + self.audio_dir = audio_dir + self.audio_filenames = sorted( + path.splitext(path.basename(f))[0] for f in listdir(audio_dir) + ) + + # Verify that audio_dir and annotations_file alphabetically map 1:1 + assert ( + self.keys == self.audio_filenames + ), f"Expected every key/note_str from annotations_file to match every filename from audio_dir\n Instead, audio_dir is missing {set(self.keys) - set(self.audio_filenames)} and/or annotations_file is mising {set(self.audio_filenames) - set(self.keys)}" + + def __len__(self): + return len(self.keys) + + def __getitem__(self, i: int) -> Tensor: + """Load the `i`-th example + + Parameters + ---------- + `i` : `int` + Example index (by alphabetical order of the example filename/"note_str") + + Returns + ------- + Flattened magnitude spectrogram of the example + # NSynth example features : `NSynthExample` + # See the [NSynth example features](https://magenta.tensorflow.org/datasets/nsynth#example-features) + """ + annotation: dict = self.annotations[self.keys[i]] + with suppress(KeyError): # Remove the "qualities_str" key, if not already + annotation.pop("qualities_str") + + return forward_transform( + load(path.join(self.audio_dir, self.audio_filenames[i] + ".wav"))[0][0] + ) + + return NSynthExample( + **annotation, + audio=from_numpy( + read(path.join(self.audio_dir, self.audio_filenames[i] + ".wav"))[1] + ) + / 32768.0, # 2**15 (to normalize 16-bit samples to [-1, 1]) + ) + + +if __name__ == "__main__": + p = ArgumentParser() + p.add_argument("--subset-path", type=str, required=True) + + args = p.parse_args() + + print(f"TEST: Loading NSynthDataset and printing its last example") + + D = NSynthDataset( + path.join(args.subset_path, "examples.json"), + path.join(args.subset_path, "audio"), + ) + + print(D[-1]) diff --git a/timbre/model/train/run.py b/timbre/train/run.py similarity index 84% rename from timbre/model/train/run.py rename to timbre/train/run.py index 9e1f1e2..cc1d197 100644 --- a/timbre/model/train/run.py +++ b/timbre/train/run.py @@ -1,22 +1,9 @@ +import timbre + from torch.optim.adamw import AdamW from torch.utils.data import DataLoader -from timbre.datasets.nsynth import NSynthDataset -from timbre.model.config.single import ( - DATASETS_DIR, - BATCH_SIZE, - NUM_WORKERS, - PIN_MEMORY, - INPUT_DIM, - HIDDEN_DIM, - LATENT_DIM, - DEVICE, - WEIGHT_DECAY, - WRITER, - N_EPOCHS, - RUNS_DIR, - DATETIME_NOW, -) +from timbre.data.nsynth import NSynthDataset from timbre import train, test, plot, VAE diff --git a/timbre/model/train/utils.py b/timbre/train/utils.py similarity index 100% rename from timbre/model/train/utils.py rename to timbre/train/utils.py From 4064d8039ee03ac2f3f08f65264f0c66ffc57294 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:17:17 -0600 Subject: [PATCH 08/18] [requirements] Add tensorboard and yapecs --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1be396d..7b625c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,8 @@ ipykernel librosa matplotlib pip +tensorboard torchaudio torchvision -tqdm \ No newline at end of file +tqdm +yapecs \ No newline at end of file From ed8f9824ae853e0f99d15a429cf06eb80c6c8270 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:19:02 -0600 Subject: [PATCH 09/18] [config] Include SummaryWriter and temporarily lower dims --- timbre/config/defaults.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py index 52b5d41..6f6baca 100644 --- a/timbre/config/defaults.py +++ b/timbre/config/defaults.py @@ -1,6 +1,8 @@ # Based off of https://github.com/maxrmorrison/deep-learning-project-template/blob/main/NAME/config/defaults.py +from datetime import datetime from pathlib import Path +from torch.utils.tensorboard.writer import SummaryWriter from .utils import get_device, should_pin_memory @@ -70,15 +72,17 @@ LEARN_RATE = 1e-3 WEIGHT_DECAY = 1e-2 NUM_EPOCHS = 64 -INPUT_DIM = 16064 +# INPUT_DIM = 16064 +INPUT_DIM = 256 LATENT_DIM = 2 -HIDDEN_DIM = 8064 +# HIDDEN_DIM = 8064 +HIDDEN_DIM = 32 DEVICE = get_device() USE_PIN_MEMORY = should_pin_memory() -# DATETIME_NOW = datetime.now().strftime("%Y%m%d-%H%M%S") -# WRITER = SummaryWriter(RUNS_DIR / f"log_{DATETIME_NOW}") +DATETIME_NOW = datetime.now().strftime("%Y%m%d-%H%M%S") +WRITER = SummaryWriter(RUNS_DIR / f"log_{DATETIME_NOW}") ############################################################################### From 042af82b147863d53435bbed336100c5bf68b999 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:19:27 -0600 Subject: [PATCH 10/18] [train] Correct config imports and paths --- timbre/train/run.py | 69 +++++++++++++++++++++++++++++++------------ timbre/train/utils.py | 2 +- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/timbre/train/run.py b/timbre/train/run.py index cc1d197..bce0fd6 100644 --- a/timbre/train/run.py +++ b/timbre/train/run.py @@ -4,51 +4,82 @@ from torch.utils.data import DataLoader from timbre.data.nsynth import NSynthDataset -from timbre import train, test, plot, VAE +from timbre.train.utils import train, test, plot +from timbre.model.vae import VAE -if __name__ == "__main__": +def main() -> None: + print(timbre.DEVICE) print("Loading datasets and dataloaders...") + # [TODO] Generalize datasets and dataloaders + # [TODO] Feature extraction TRAIN_DATA = NSynthDataset( - DATASETS_DIR / "nsynth" / "nsynth-valid" / "examples.json", - DATASETS_DIR / "nsynth" / "nsynth-valid" / "audio", + timbre.SOURCES_DIR / "nsynth" / "nsynth-valid" / "examples.json", + timbre.SOURCES_DIR / "nsynth" / "nsynth-valid" / "audio", ) TEST_DATA = NSynthDataset( - DATASETS_DIR / "nsynth" / "nsynth-test" / "examples.json", - DATASETS_DIR / "nsynth" / "nsynth-test" / "audio", + timbre.SOURCES_DIR / "nsynth" / "nsynth-test" / "examples.json", + timbre.SOURCES_DIR / "nsynth" / "nsynth-test" / "audio", ) TRAIN_LOADER = DataLoader( TRAIN_DATA, - batch_size=BATCH_SIZE, + batch_size=timbre.BATCH_SIZE, shuffle=True, - num_workers=NUM_WORKERS, - pin_memory=PIN_MEMORY, + num_workers=timbre.NUM_WORKERS, + pin_memory=timbre.USE_PIN_MEMORY, ) TEST_LOADER = DataLoader( TEST_DATA, - batch_size=BATCH_SIZE, + batch_size=timbre.BATCH_SIZE, shuffle=False, - num_workers=NUM_WORKERS, - pin_memory=PIN_MEMORY, + num_workers=timbre.NUM_WORKERS, + pin_memory=timbre.USE_PIN_MEMORY, ) print( f"Training datapoints: {len(TRAIN_DATA)}\nTesting datapoints: {len(TEST_DATA)}\n" ) print("Initiating model, optimizer, and Tensorboard...") - MODEL = VAE(INPUT_DIM, HIDDEN_DIM, LATENT_DIM).to(DEVICE) - OPT = AdamW(MODEL.parameters(), weight_decay=WEIGHT_DECAY) + MODEL = VAE(timbre.INPUT_DIM, timbre.HIDDEN_DIM, timbre.LATENT_DIM).to( + timbre.DEVICE + ) + OPT = AdamW(MODEL.parameters(), weight_decay=timbre.WEIGHT_DECAY) print("Entering train-test loop...\n") prev_updates = 0 - for epoch in range(N_EPOCHS): - print(f"Epoch {epoch+1}/{N_EPOCHS}") + for epoch in range(timbre.NUM_EPOCHS): + print(f"Epoch {epoch+1}/{timbre.NUM_EPOCHS}") prev_updates = train( - MODEL, TRAIN_LOADER, OPT, prev_updates, DEVICE, BATCH_SIZE, WRITER + MODEL, + TRAIN_LOADER, + OPT, + prev_updates, + timbre.DEVICE, + timbre.BATCH_SIZE, + timbre.WRITER, + ) + test( + MODEL, + TEST_LOADER, + prev_updates, + timbre.DEVICE, + timbre.LATENT_DIM, + timbre.WRITER, ) - test(MODEL, TEST_LOADER, prev_updates, DEVICE, LATENT_DIM, WRITER) print("\nPlotting...") - plot(MODEL, TRAIN_LOADER, DEVICE, LATENT_DIM, RUNS_DIR, DATETIME_NOW) + plot( + MODEL, + TRAIN_LOADER, + timbre.DEVICE, + timbre.LATENT_DIM, + timbre.RUNS_DIR, + timbre.DATETIME_NOW, + ) print("Done.") + + +if __name__ == "__main__": + # [TODO] Arguments + main() diff --git a/timbre/train/utils.py b/timbre/train/utils.py index 5e422e9..5762a08 100644 --- a/timbre/train/utils.py +++ b/timbre/train/utils.py @@ -11,7 +11,7 @@ from torch.utils.tensorboard.writer import SummaryWriter from tqdm import tqdm -from ..vae import VAE +from ..model.vae import VAE def train( From 6f7a9090ad5bfa405da0eceba7108b3b18684742 Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:24:43 -0600 Subject: [PATCH 11/18] [config] Increase dims to avoid empty layers --- timbre/config/defaults.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py index 6f6baca..83acff5 100644 --- a/timbre/config/defaults.py +++ b/timbre/config/defaults.py @@ -73,10 +73,10 @@ WEIGHT_DECAY = 1e-2 NUM_EPOCHS = 64 # INPUT_DIM = 16064 -INPUT_DIM = 256 +INPUT_DIM = 2048 LATENT_DIM = 2 # HIDDEN_DIM = 8064 -HIDDEN_DIM = 32 +HIDDEN_DIM = 1024 DEVICE = get_device() From 0cf27575184d8c5e57ec7fef821ce4866c059dfa Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:05:52 -0600 Subject: [PATCH 12/18] [data] Move dataset and dataloading logic out and refactor NSynth path --- timbre/data/nsynth.py | 23 ++++++++++------------- timbre/train/run.py | 43 ++++++++++++++++++++++++------------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/timbre/data/nsynth.py b/timbre/data/nsynth.py index 7ed6dc1..4606280 100644 --- a/timbre/data/nsynth.py +++ b/timbre/data/nsynth.py @@ -43,17 +43,17 @@ class NSynthDataset(Dataset): Parameters ---------- - `annotations_file` : `str` | `Path` - Path to the JSON file containing the annotations - `audio_dir` : `str` | `Path` - Path to the directory containing the audio files + `source_dir` : `str` | `Path` + Path to a subset of the unprocessed NSynth dataset directory containing `examples.json` and `audio` directory """ - def __init__( - self, - annotations_file: str | Path, - audio_dir: str | Path, - ) -> None: + def __init__(self, source_dir: str | Path) -> None: + if isinstance(source_dir, str): + source_dir = Path(source_dir) + + annotations_file = source_dir / "examples.json" + audio_dir = source_dir / "audio" + with open(annotations_file, "r") as f: self.annotations: dict[str, Any] = json.load(f) self.keys = sorted(self.annotations.keys()) # note_strs @@ -109,9 +109,6 @@ def __getitem__(self, i: int) -> Tensor: print(f"TEST: Loading NSynthDataset and printing its last example") - D = NSynthDataset( - path.join(args.subset_path, "examples.json"), - path.join(args.subset_path, "audio"), - ) + D = NSynthDataset(args.subset_path) print(D[-1]) diff --git a/timbre/train/run.py b/timbre/train/run.py index bce0fd6..0ccb174 100644 --- a/timbre/train/run.py +++ b/timbre/train/run.py @@ -8,36 +8,41 @@ from timbre.model.vae import VAE -def main() -> None: - print(timbre.DEVICE) - print("Loading datasets and dataloaders...") - # [TODO] Generalize datasets and dataloaders - # [TODO] Feature extraction - TRAIN_DATA = NSynthDataset( - timbre.SOURCES_DIR / "nsynth" / "nsynth-valid" / "examples.json", - timbre.SOURCES_DIR / "nsynth" / "nsynth-valid" / "audio", +# [TODO] Generalize datasets and dataloaders +# [TODO] Feature extraction + + +def build_datasets() -> tuple[NSynthDataset, NSynthDataset]: + return ( + NSynthDataset(timbre.SOURCES_DIR / "nsynth" / "nsynth-valid"), + NSynthDataset(timbre.SOURCES_DIR / "nsynth" / "nsynth-test"), ) - TEST_DATA = NSynthDataset( - timbre.SOURCES_DIR / "nsynth" / "nsynth-test" / "examples.json", - timbre.SOURCES_DIR / "nsynth" / "nsynth-test" / "audio", + + +def build_dataloaders() -> tuple[DataLoader, DataLoader]: + print("Loading datasets and dataloaders...") + train_data, test_data = build_datasets() + print( + f"Training datapoints: {len(train_data)}\nTesting datapoints: {len(test_data)}\n" ) - TRAIN_LOADER = DataLoader( - TRAIN_DATA, + + return DataLoader( + train_data, batch_size=timbre.BATCH_SIZE, shuffle=True, num_workers=timbre.NUM_WORKERS, pin_memory=timbre.USE_PIN_MEMORY, - ) - TEST_LOADER = DataLoader( - TEST_DATA, + ), DataLoader( + test_data, batch_size=timbre.BATCH_SIZE, shuffle=False, num_workers=timbre.NUM_WORKERS, pin_memory=timbre.USE_PIN_MEMORY, ) - print( - f"Training datapoints: {len(TRAIN_DATA)}\nTesting datapoints: {len(TEST_DATA)}\n" - ) + + +def main() -> None: + TRAIN_LOADER, TEST_LOADER = build_dataloaders() print("Initiating model, optimizer, and Tensorboard...") MODEL = VAE(timbre.INPUT_DIM, timbre.HIDDEN_DIM, timbre.LATENT_DIM).to( From 3b9efacebd9a4ae08b71dff30512b520cebd9f9c Mon Sep 17 00:00:00 2001 From: quinnouyang <90884224+quinnouyang@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:12:19 -0600 Subject: [PATCH 13/18] [train] Add TODOs --- timbre/train/run.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/timbre/train/run.py b/timbre/train/run.py index 0ccb174..1e79aa9 100644 --- a/timbre/train/run.py +++ b/timbre/train/run.py @@ -9,7 +9,13 @@ # [TODO] Generalize datasets and dataloaders -# [TODO] Feature extraction +# [TODO] Feature extraction (actually figure out I/O with reasonable dimensions) +# [TODO] Get it to run on malleus +# [TODO] Tensorboard +# [TODO] Achieve basic clustering results and audio output +# [TODO] Checkpoints +# [TODO] DDP +# [TODO] Arguments def build_datasets() -> tuple[NSynthDataset, NSynthDataset]: @@ -86,5 +92,4 @@ def main() -> None: if __name__ == "__main__": - # [TODO] Arguments main() From 39d79c1ca2438b6426450dc48ce9fd19a9cec023 Mon Sep 17 00:00:00 2001 From: Quinn Ouyang <90884224+quinnouyang@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:21:34 -0600 Subject: [PATCH 14/18] [pyproject] Refine URLs --- pyproject.toml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c914364..f6d3d98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,13 @@ classifiers = [ dynamic = ["dependencies"] [project.urls] -Homepage = "https://github.com/quinnouyang/timbre" -# Documentation = "" -Repository = "https://github.com/quinnouyang/timbre.git" -"Bug Tracker" = "https://github.com/quinnouyang/timbre/issues" -# Changelog = "" +# homepage = "" +source = "https://github.com/quinnouyang/timbre" +github = "https://github.com/quinnouyang/timbre.git" +# download = "" +# changelog = "" +# documentation = "" +issues = "https://github.com/quinnouyang/timbre/issues" [tool.hatch.metadata] allow-direct-references = true From 90c56e6f9c76e6207d59f0b0a3578f1ed526c3d5 Mon Sep 17 00:00:00 2001 From: Quinn Ouyang <90884224+quinnouyang@users.noreply.github.com> Date: Sat, 7 Dec 2024 16:06:06 -0600 Subject: [PATCH 15/18] [config] Extend default config with malleus option --- timbre/config/defaults.py | 18 ++++++++++-------- timbre/config/malleus.py | 14 ++++++++++++++ timbre/train/run.py | 9 ++++++++- 3 files changed, 32 insertions(+), 9 deletions(-) create mode 100644 timbre/config/malleus.py diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py index 83acff5..46d1068 100644 --- a/timbre/config/defaults.py +++ b/timbre/config/defaults.py @@ -25,27 +25,29 @@ # Directories ############################################################################### -PACKAGE_DIR = Path(__file__).parent.parent +_PACKAGE_DIR = Path(__file__).parent.parent # For assets to bundle with pip release -# ASSETS_DIR = PACKAGE_DIR / "assets" +# ASSETS_DIR = _PACKAGE_DIR / "assets" -ROOT_DIR = PACKAGE_DIR.parent +_ROOT_DIR = _PACKAGE_DIR.parent + +_DATA_DIR = _ROOT_DIR / "datasets" # For preprocessed features -# CACHE_DIR = ROOT_DIR.parent / "data" / "cache" +# CACHE_DIR = _DATA_DIR / "cache" # For unprocessed datasets -SOURCES_DIR = ROOT_DIR / "datasets" / "sources" +SOURCES_DIR = _DATA_DIR / "sources" # For preprocessed datasets -PREPROCESSED_DIR = ROOT_DIR / "datasets" / "preprocessed" +PREPROCESSED_DIR = _DATA_DIR / "preprocessed" # For training and adaptation artifacts -RUNS_DIR = ROOT_DIR / "runs" +RUNS_DIR = _ROOT_DIR / "runs" # For evaluation artifacts -# EVAL_DIR = ROOT_DIR.parent / "eval" +# EVAL_DIR = _ROOT_DIR.parent / "eval" ############################################################################### diff --git a/timbre/config/malleus.py b/timbre/config/malleus.py new file mode 100644 index 0000000..f2faa38 --- /dev/null +++ b/timbre/config/malleus.py @@ -0,0 +1,14 @@ +from pathlib import Path + +MODULE = "timbre" + +# Configuration name +CONFIG = "malleus" + +_DATA_DIR = Path("/") / "mnt" / "data" / "quinn" / "timbre" + +# For unprocessed datasets +SOURCES_DIR = _DATA_DIR / "sources" + +# For preprocessed features +PREPROCESSED_DIR = _DATA_DIR / "preprocessed" diff --git a/timbre/train/run.py b/timbre/train/run.py index 1e79aa9..a555fb1 100644 --- a/timbre/train/run.py +++ b/timbre/train/run.py @@ -1,5 +1,7 @@ import timbre +from argparse import ArgumentParser +from pathlib import Path from torch.optim.adamw import AdamW from torch.utils.data import DataLoader @@ -20,7 +22,7 @@ def build_datasets() -> tuple[NSynthDataset, NSynthDataset]: return ( - NSynthDataset(timbre.SOURCES_DIR / "nsynth" / "nsynth-valid"), + NSynthDataset(timbre.SOURCES_DIR / "nsynth" / "nsynth-train"), NSynthDataset(timbre.SOURCES_DIR / "nsynth" / "nsynth-test"), ) @@ -48,6 +50,11 @@ def build_dataloaders() -> tuple[DataLoader, DataLoader]: def main() -> None: + parser = ArgumentParser() + parser.add_argument("--config", type=Path, nargs="+", help="Configuration file(s)") + + print(f"Using {timbre.CONFIG} configuration\n") + TRAIN_LOADER, TEST_LOADER = build_dataloaders() print("Initiating model, optimizer, and Tensorboard...") From e7edfcab961511b550fe1ddbda5f9db0684c2144 Mon Sep 17 00:00:00 2001 From: Quinn Ouyang <90884224+quinnouyang@users.noreply.github.com> Date: Sat, 7 Dec 2024 16:42:07 -0600 Subject: [PATCH 16/18] [requirements] Add torch profiler and isort --- pyproject.toml | 3 +++ requirements.txt | 2 ++ 2 files changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f6d3d98..45ece59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,6 @@ files = ["requirements.txt"] # [tool.hatch.metadata.hooks.requirements_txt.optional-dependencies] # cli = ["requirements-dev.txt"] + +[tool.isort] +profile = "black" diff --git a/requirements.txt b/requirements.txt index 7b625c3..1bc21e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ black ipykernel +isort librosa matplotlib pip tensorboard +torch_tb_profiler torchaudio torchvision tqdm From 83259f4c506eb83ed76d98b036dd5d2f9ad095ee Mon Sep 17 00:00:00 2001 From: Quinn Ouyang <90884224+quinnouyang@users.noreply.github.com> Date: Sat, 7 Dec 2024 17:48:00 -0600 Subject: [PATCH 17/18] Apply isort --- pyproject.toml | 1 + timbre/__init__.py | 10 ++++------ timbre/config/defaults.py | 2 +- timbre/data/nsynth.py | 6 +++--- timbre/model/vae.py | 4 ++-- timbre/train/run.py | 7 +++---- timbre/train/utils.py | 6 +++--- 7 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45ece59..fbb4f98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,3 +46,4 @@ files = ["requirements.txt"] [tool.isort] profile = "black" +skip_glob = ["reference_models/*"] diff --git a/timbre/__init__.py b/timbre/__init__.py index 1aabc53..8badc53 100644 --- a/timbre/__init__.py +++ b/timbre/__init__.py @@ -3,18 +3,16 @@ ############################################################################### -from .config import defaults - import yapecs -yapecs.configure("timbre", defaults) +from .config import defaults -from .config.defaults import * +yapecs.configure("timbre", defaults) ############################################################################### # Module ############################################################################### # from .train import ... -from . import data -from . import model +from . import data, model +from .config.defaults import * diff --git a/timbre/config/defaults.py b/timbre/config/defaults.py index 46d1068..525e55c 100644 --- a/timbre/config/defaults.py +++ b/timbre/config/defaults.py @@ -2,11 +2,11 @@ from datetime import datetime from pathlib import Path + from torch.utils.tensorboard.writer import SummaryWriter from .utils import get_device, should_pin_memory - # Configuration name CONFIG = "defaults" diff --git a/timbre/data/nsynth.py b/timbre/data/nsynth.py index 4606280..bb95069 100644 --- a/timbre/data/nsynth.py +++ b/timbre/data/nsynth.py @@ -1,12 +1,12 @@ import json - from argparse import ArgumentParser from contextlib import suppress -from os import path, listdir +from os import listdir, path from pathlib import Path +from typing import Any, Literal, NamedTuple + from torch import Tensor, float32 from torch.utils.data import Dataset -from typing import NamedTuple, Literal, Any from torchaudio import load, transforms transform_melspec = transforms.MelSpectrogram(n_fft=512, n_mels=64) diff --git a/timbre/model/vae.py b/timbre/model/vae.py index 59b82bc..903c786 100644 --- a/timbre/model/vae.py +++ b/timbre/model/vae.py @@ -1,8 +1,8 @@ +from dataclasses import dataclass + import torch import torch.nn as nn import torch.nn.functional as F - -from dataclasses import dataclass from torch.distributions.multivariate_normal import MultivariateNormal diff --git a/timbre/train/run.py b/timbre/train/run.py index a555fb1..f969aa1 100644 --- a/timbre/train/run.py +++ b/timbre/train/run.py @@ -1,14 +1,13 @@ -import timbre - from argparse import ArgumentParser from pathlib import Path + from torch.optim.adamw import AdamW from torch.utils.data import DataLoader +import timbre from timbre.data.nsynth import NSynthDataset -from timbre.train.utils import train, test, plot from timbre.model.vae import VAE - +from timbre.train.utils import plot, test, train # [TODO] Generalize datasets and dataloaders # [TODO] Feature extraction (actually figure out I/O with reasonable dimensions) diff --git a/timbre/train/utils.py b/timbre/train/utils.py index 5762a08..f6f1c36 100644 --- a/timbre/train/utils.py +++ b/timbre/train/utils.py @@ -1,11 +1,11 @@ +from pathlib import Path + import matplotlib.pyplot as plt import numpy as np import torch - from matplotlib.colors import LogNorm -from pathlib import Path -from torch.nn.utils import clip_grad_norm_ from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torch.utils.tensorboard.writer import SummaryWriter From 995538a2ab7b0511fd4fd344e855d856ec691761 Mon Sep 17 00:00:00 2001 From: Quinn Ouyang <90884224+quinnouyang@users.noreply.github.com> Date: Sat, 7 Dec 2024 23:58:12 -0600 Subject: [PATCH 18/18] [pyproject] Exclude reference_models from formatting --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fbb4f98..3df11a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ files = ["requirements.txt"] # [tool.hatch.metadata.hooks.requirements_txt.optional-dependencies] # cli = ["requirements-dev.txt"] +[tool.black] +extend-exclude = "reference_models" + [tool.isort] profile = "black" -skip_glob = ["reference_models/*"] +skip_glob = ["reference_models"]