Feat/optimizations #8

Merged
4 commits merged on Aug 24, 2023
5 changes: 3 additions & 2 deletions README.md
@@ -39,11 +39,12 @@ options:
-t TOKENS_PATH, --tokens_path TOKENS_PATH
The output path where tokens are saved
-m MIDIS_PATH, --midis_path MIDIS_PATH
The path where MIDI files can be located
The path where MIDI files can be located or a file containing a list of paths
-g MIDIS_GLOB, --midis_glob MIDIS_GLOB
The glob pattern used to locate MIDI files
-b, --bpe Applies BPE to the corpora of tokens
-p, --process Extracts tokens from the MIDI files
-p PARAMS_PATH, --preload PARAMS_PATH
Absolute path to existing token_params.cfg settings
-a {REMI,MMM}, --algo {REMI,MMM}
Tokenization algorithm
-c CLASSES, --classes CLASSES
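The updated `--midis_path` help means the argument now accepts either a directory of MIDI files or a plain-text file listing one MIDI path per line. A minimal sketch of that resolution logic, mirroring what the new sanitizer.py further below does (`resolve_midi_paths` is a hypothetical helper, not part of the PR):

```python
import os
from pathlib import Path

def resolve_midi_paths(midis_path: str, midis_glob: str = '*.mid'):
    """Return MIDI paths from a directory glob, or from a text file listing one path per line."""
    if os.path.isfile(midis_path):
        # plain-text list of paths; blank lines are skipped
        return [line.strip() for line in open(midis_path) if line.strip()]
    return list(Path(midis_path).glob(midis_glob))
```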
2 changes: 1 addition & 1 deletion dev/cloud-commands.sh
@@ -6,7 +6,7 @@ apt install cuda-toolkit-11-7
pip install -q -U fastapi==0.93.0
pip install -q -U gdown

gdown --fuzzy -O ~/rwkv-2.pth https://drive.google.com/file/d/1dQ_QLPgIb7crfdwwkI9yrfijQVeW3RD1/view?usp=sharing
gdown --fuzzy -O ~/rwkv-92.pth https://drive.google.com/file/d/1kDNgfYBcq4vsbDkwTUIKBESxJaIg_mwO/view?usp=sharing
gdown --fuzzy -O ~/all-es.binidx.tgz https://drive.google.com/file/d/108wQdxHM6CJNvliWlVqOLCPahDYq0E9p/view?usp=sharing
tar xzvf all-es.binidx.tgz

6 changes: 3 additions & 3 deletions notebooks/vae.ipynb
@@ -292,16 +292,16 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([2048, 1024]), 0.010902252979576588)"
"(torch.Size([2048, 1024]), 0.010932102799415588)"
]
},
"execution_count": 27,
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ lightning==2.0.2
loguru==0.7.0
miditok==2.1.1
miditoolkit @ git+https://github.com/webpolis/miditoolkit@master
music21==9.1.0
numpy==1.23.5
psutil==5.9.5
ray==2.4.0
4 changes: 2 additions & 2 deletions src/model/embed.py
@@ -81,7 +81,7 @@ def __init__(self, hidden_dims: List = [HIDDEN_DIM], latent_dim=LATENT_DIM, embe

self.module = nn.Sequential(
*modules,
nn.Linear(hidden_dims[-1], embed_dim, bias=False)
nn.Linear(hidden_dims[-1], embed_dim, bias=True)
)

def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
@@ -105,7 +105,7 @@ def __init__(self,
self.emb = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.encoder = Encoder(hidden_dims, latent_dim, embed_dim)
self.decoder = Decoder(hidden_dims, latent_dim, embed_dim)
self.z_emb = nn.Linear(latent_dim, embed_dim, bias=False)
self.z_emb = nn.Linear(latent_dim, embed_dim, bias=True)
self.proj = nn.Linear(embed_dim, vocab_size, bias=True)
self.ln_out = nn.LayerNorm(vocab_size)

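Both changed layers in embed.py switch from `bias=False` to `bias=True`, so each projection gains one learnable offset per output feature. A standalone sketch of the effect (the 64 and 512 sizes are illustrative, not the model's actual dimensions):

```python
import torch
from torch import nn

# With bias=True the layer computes x @ W.T + b rather than a pure matrix multiply.
z_emb = nn.Linear(64, 512, bias=True)

z = torch.randn(8, 64)        # a batch of latent vectors
out = z_emb(z)                # shape (8, 512)
print(z_emb.bias.shape)       # torch.Size([512]): the extra parameters this change enables
```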
3 changes: 3 additions & 0 deletions src/model/model.py
@@ -444,6 +444,7 @@ def __init__(self, args):
latent_dim,
hidden_dim
)
self.emb_norm = nn.LayerNorm(args.n_embd)
else:
self.emb = VAE(
embed_dim,
@@ -552,6 +553,8 @@ def forward(self, idx):
else:
output, x, emb, hidden, mean, logvar = self.emb(idx)

x = self.emb_norm(x)

self.register_buffer('emb_input', idx.detach().clone(), persistent=False)
self.register_buffer('emb_output', output, persistent=False)
self.register_buffer('emb_hat', x.detach().clone(), persistent=False)
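The model.py change adds a LayerNorm over the embedding dimension and applies it to the VAE's reconstructed embedding before it flows on, keeping each token's embedding vector at zero mean and unit variance with a learned scale and shift. A standalone sketch (shapes are illustrative; the model reads the width from `args.n_embd`):

```python
import torch
from torch import nn

n_embd = 512                         # illustrative; the model uses args.n_embd
emb_norm = nn.LayerNorm(n_embd)

x = torch.randn(4, 128, n_embd)      # (batch, sequence, embedding) as produced by the VAE
x = emb_norm(x)                      # normalize each embedding vector, then apply learned scale/shift

print(x.mean(-1).abs().max())        # close to 0: every token embedding is centered
```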
100 changes: 100 additions & 0 deletions src/tools/sanitizer.py
@@ -0,0 +1,100 @@
"""
MusAI

Author: Nicolás Iglesias
Email: nfiglesias@gmail.com

This file is part of MusAI, a project for generating MIDI music using
a combination of machine learning algorithms.

This script offers optional methods to sanitize a set of MIDI files.

MIT License
Copyright (c) [2023] [Nicolás Iglesias]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import os
import re
import ray
import argparse
from loguru import logger
from pathlib import Path
from music21 import converter, note, chord
from music21.stream.base import Score
from tqdm import tqdm
from tokenizer import deco, ProgressBar, ProgressBarActor


def trim_midi(score: Score):
start_measure = None
end_measure = 0

for element in score.flatten().elements:
if isinstance(element, note.Note) or \
isinstance(element, note.Rest) or \
isinstance(element, note.Unpitched) or \
isinstance(element, chord.Chord):
if start_measure is None and not element.isRest:
start_measure = element.measureNumber
if not element.isRest and element.measureNumber > end_measure:
end_measure = element.measureNumber

return start_measure, end_measure


if __name__ == "__main__":
# parse command line arguments
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-m', '--midis_path', default=None,
help='The path where MIDI files can be located', type=str)
arg_parser.add_argument('-g', '--midis_glob', default='*.mid',
help='The glob pattern used to locate MIDI files', type=str)
arg_parser.add_argument('-o', '--output_path', default='out',
help='The path where the sanitized MIDI files will be saved', type=str)
arg_parser.add_argument('-n', '--rename', help='Sanitize filename for convenience',
action='store_true', default=True)
arg_parser.add_argument('-t', '--trim', help='Remove silence from beginning and end of MIDI songs',
action='store_true', default=True)
args = arg_parser.parse_args()

Path(args.output_path).mkdir(parents=True, exist_ok=True)

if os.path.isfile(args.midis_path):
midi_file_paths = [line.strip()
for line in open(args.midis_path) if line.strip()]
else:
midi_file_paths = list(Path(args.midis_path).glob(args.midis_glob))

for midi_path in tqdm(midi_file_paths):
midi_score = converter.parse(midi_path)
fname = Path(midi_path).name

if args.rename:
# rename
fname = re.sub(r'[^a-z\d\.]{1,}', '_', fname.lower(), flags=re.IGNORECASE)

if args.trim:
# trim MIDI
start_measure, end_measure = trim_midi(midi_score)
trimmed_score = midi_score.measures(start_measure, end_measure)
trim_path = re.sub(r'([\w\d_\-]+)(\.[a-z]+)$', '\\1_trim\\2', fname)
trim_path = f'{args.output_path}/{trim_path}'

trimmed_score.write('midi', fp=trim_path)
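A quick way to exercise the new `trim_midi` helper on a single file, assuming music21 is installed (`song.mid` and the output path are hypothetical):

```python
from music21 import converter
from sanitizer import trim_midi       # the helper added in this PR

score = converter.parse('song.mid')   # hypothetical input file
start, end = trim_midi(score)         # first and last measures that contain notes or chords
trimmed = score.measures(start, end)  # drop the silent measures at both ends
trimmed.write('midi', fp='song_trim.mid')
```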
113 changes: 60 additions & 53 deletions src/tools/tokenizer.py
@@ -175,13 +175,14 @@ def deco(func): return func
arg_parser.add_argument('-t', '--tokens_path', default=TOKENS_PATH,
help='The output path where tokens are saved', type=str)
arg_parser.add_argument('-m', '--midis_path', default=MIDIS_PATH,
help='The path where MIDI files can be located', type=str)
arg_parser.add_argument('-g', '--midis_glob', default='*mix*.mid',
help='The path where MIDI files can be located or a file containing a list of paths',
type=str)
arg_parser.add_argument('-g', '--midis_glob', default='*.mid',
help='The glob pattern used to locate MIDI files', type=str)
arg_parser.add_argument('-b', '--bpe', help='Applies BPE to the corpora of tokens',
action='store_true', default=False)
arg_parser.add_argument('-p', '--process', help='Extracts tokens from the MIDI files',
action='store_true', default=False)
arg_parser.add_argument('-p', '--preload', help='Absolute path to existing token_params.cfg settings',
default='token_params.cfg', type=str)
arg_parser.add_argument('-a', '--algo', help='Tokenization algorithm',
choices=TOKENIZER_ALGOS, default='REMI', type=str)
arg_parser.add_argument('-c', '--classes', help='Only extract these instrument classes (e.g. 1,14,16,3,4,10,11)',
@@ -292,7 +293,8 @@ def get_tokenizer(params=None, algo='MMM', programs=None):
tokenizer = MMM(density_bins_max=(10, 20), tokenizer_config=TokenizerConfig(
**TOKENIZER_PARAMS), params=params)

logger.info('Tokenizer initialized. Using {algo}', algo=algo)
logger.info(
'Tokenizer initialized. Using {algo} ({size})', algo=algo, size=len(tokenizer))

return tokenizer

@@ -460,63 +462,69 @@ def get_collection_refs(midis_path=None, midis_glob=None, classes=None, classes_
MIDI_TITLES = []
MIDI_PROGRAMS = []

if args.process:
if not args.debug:
# starts orchestration
ray.init(num_cpus=psutil.cpu_count())
if not args.debug:
# starts orchestration
ray.init(num_cpus=psutil.cpu_count())

MIDI_COLLECTION_REFS = get_collection_refs(
args.midis_path, args.midis_glob, args.classes, args.classes_req, args.length)
MIDI_COLLECTION_REFS = get_collection_refs(
args.midis_path, args.midis_glob, args.classes, args.classes_req, args.length)

for ray_midi_ref in MIDI_COLLECTION_REFS:
midi_doc = ray.get(ray_midi_ref)
for ray_midi_ref in MIDI_COLLECTION_REFS:
midi_doc = ray.get(ray_midi_ref)

if midi_doc != None:
MIDI_TITLES.append(midi_doc['name'])
MIDI_PROGRAMS.append(midi_doc['programs'])
else:
MIDI_COLLECTION_REFS = get_collection_refs(
args.midis_path, args.midis_glob, args.classes, args.classes_req, args.length, args.debug)
MIDI_TITLES = [midi_ref['name'] for midi_ref in MIDI_COLLECTION_REFS]
MIDI_PROGRAMS = [midi_ref['programs'] for midi_ref in MIDI_COLLECTION_REFS]
if midi_doc != None:
MIDI_TITLES.append(midi_doc['name'])
MIDI_PROGRAMS.append(midi_doc['programs'])
else:
MIDI_COLLECTION_REFS = get_collection_refs(
args.midis_path, args.midis_glob, args.classes, args.classes_req, args.length, args.debug)
MIDI_TITLES = [midi_ref['name'] for midi_ref in MIDI_COLLECTION_REFS]
MIDI_PROGRAMS = [midi_ref['programs'] for midi_ref in MIDI_COLLECTION_REFS]

logger.info('Processing tokenization: {collection_size} documents', collection_size=len(
MIDI_COLLECTION_REFS))
logger.info('Processing tokenization: {collection_size} documents', collection_size=len(
MIDI_COLLECTION_REFS))

Path(args.tokens_path).mkdir(parents=True, exist_ok=True)
Path(args.tokens_path).mkdir(parents=True, exist_ok=True)

# collect used programs
programs_used = [program for program in list(
set(reduce(iconcat, MIDI_PROGRAMS, [])))]
# collect used programs
programs_used = [program for program in list(
set(reduce(iconcat, MIDI_PROGRAMS, [])))]

# initializes tokenizer
TOKENIZER = get_tokenizer(programs=programs_used, algo=args.algo)
# initializes tokenizer
params_path = os.path.join(args.tokens_path, args.preload)
preloads = Path(params_path).is_file()

# process tokenization via Ray
if not args.debug:
pb = ProgressBar(len(MIDI_COLLECTION_REFS))
actor = pb.actor
midi_collection_refs = MIDI_COLLECTION_REFS
else:
actor = None
midi_collection_refs = tqdm(MIDI_COLLECTION_REFS)
if preloads:
logger.info('Preloading params from {path}', path=params_path)

tokenize_call = tokenize_set if args.debug else tokenize_set.remote
ray_tokenized_refs = [tokenize_call(ray_midi_ref, args.tokens_path, TOKENIZER,
actor, bpe=args.bpe, debug=args.debug) for ray_midi_ref in midi_collection_refs]
TOKENIZER = get_tokenizer(
programs=programs_used, algo=args.algo) if not preloads else get_tokenizer(params=params_path)

if not args.debug:
pb.print_until_done()
ray.shutdown()
# process tokenization via Ray
if not args.debug:
pb = ProgressBar(len(MIDI_COLLECTION_REFS))
actor = pb.actor
midi_collection_refs = MIDI_COLLECTION_REFS
else:
actor = None
midi_collection_refs = tqdm(MIDI_COLLECTION_REFS)

tokenize_call = tokenize_set if args.debug else tokenize_set.remote
ray_tokenized_refs = [tokenize_call(ray_midi_ref, args.tokens_path, TOKENIZER,
actor, bpe=args.bpe, debug=args.debug) for ray_midi_ref in midi_collection_refs]

if not args.debug:
pb.print_until_done()
ray.shutdown()

logger.info('Vocab size (base): {vocab_size}',
vocab_size=len(TOKENIZER.vocab))
logger.info('Saving params...')
logger.info('Vocab size (base): {vocab_size}',
vocab_size=len(TOKENIZER.vocab))
logger.info('Saving params...')

""" !IMPORTANT always store the _vocab_base when saving params.
Order of keys in the vocab may differ in a new instance of a preloaded TOKENIZER. """
TOKENIZER.save_params(
f'{args.tokens_path}/{TOKEN_PARAMS_NAME}', {'_vocab_base': TOKENIZER.vocab})
""" !IMPORTANT always store the _vocab_base when saving params.
Order of keys in the vocab may differ in a new instance of a preloaded TOKENIZER. """
TOKENIZER.save_params(
f'{args.tokens_path}/{TOKEN_PARAMS_NAME}', {'_vocab_base': TOKENIZER.vocab})

if args.bpe:
# Constructs the vocabulary with BPE, from the tokenized files
@@ -525,9 +533,8 @@ def get_collection_refs(midis_path=None, midis_glob=None, classes=None, classes_

Path(tokens_bpe_path).mkdir(parents=True, exist_ok=True)

if not args.process:
TOKENIZER = get_tokenizer(
params=f'{args.tokens_path}/{TOKEN_PARAMS_NAME}', algo=args.algo)
TOKENIZER = get_tokenizer(
params=f'{args.tokens_path}/{TOKEN_PARAMS_NAME}', algo=args.algo)

logger.info('Learning BPE from vocab size {vocab_size}...', vocab_size=len(
TOKENIZER))
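Taken together, the tokenizer.py changes drop the old `--process` gate, so MIDI collection and tokenization always run, and the new `--preload` option lets an existing token_params.cfg short-circuit vocabulary construction. A condensed sketch of that preload decision, reusing `get_tokenizer` from the diff above (`build_tokenizer` itself is a hypothetical wrapper, not part of the PR):

```python
import os
from pathlib import Path

from tokenizer import get_tokenizer   # defined in src/tools/tokenizer.py

def build_tokenizer(tokens_path: str, preload_name: str, algo: str, programs):
    """Reuse saved tokenizer settings when present; otherwise build from the collected programs."""
    params_path = os.path.join(tokens_path, preload_name)   # e.g. <tokens_path>/token_params.cfg
    if Path(params_path).is_file():
        # a previous run saved the settings (including _vocab_base), so rebuild from them
        return get_tokenizer(params=params_path)
    return get_tokenizer(programs=programs, algo=algo)
```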