-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
51 lines (37 loc) · 1.57 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Main file for training Chroma.
This file is intentionally kept short. The majority of the logic is in libraries
that can be easily tested and imported in Colab.
"""
from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf
import train
FLAGS = flags.FLAGS

# --workdir: directory where the run stores its model data. It has no default;
# the script entry point marks it as required.
flags.DEFINE_string(
    name='workdir',
    default=None,
    help='Directory to store model data.',
)

# --config: path to the training hyperparameter configuration file, loaded by
# ml_collections. lock_config=True asks ml_collections to lock the loaded
# config (per its docs, a locked config rejects newly added fields).
config_flags.DEFINE_config_file(
    'config',
    default=None,
    help_string='File path to the training hyperparameter configuration.',
    lock_config=True,
)
def main(argv):
  """Runs training and evaluation; invoked through `absl.app.run`.

  Hides GPUs from TensorFlow, logs the JAX process layout, annotates the CLU
  work unit with that layout plus a workdir artifact, and then hands off to
  `train.train_and_evaluate` with the parsed `--config` and `--workdir` flags.

  Args:
    argv: Positional command-line arguments remaining after absl flag
      parsing. Only the program name itself is accepted.

  Raises:
    app.UsageError: If any positional argument beyond the program name is
      given.
  """
  if argv[1:]:
    raise app.UsageError('Too many command-line arguments.')

  # Keep TensorFlow off the GPUs; otherwise TF might reserve GPU memory and
  # make it unavailable to JAX.
  tf.config.experimental.set_visible_devices([], 'GPU')

  # NOTE(review): distributed initialization is deliberately disabled here;
  # presumably it is re-enabled for multi-host runs -- confirm.
  # jax.distributed.initialize()

  proc_idx = jax.process_index()
  proc_cnt = jax.process_count()
  logging.info('JAX process: %d / %d', proc_idx, proc_cnt)
  logging.info('JAX local devices: %r', jax.local_devices())

  # Tag the work unit so each task can be matched to its JAX host; depending
  # on the platform, task 0 is not guaranteed to be host 0.
  platform.work_unit().set_task_status(
      f'process_index: {proc_idx}, process_count: {proc_cnt}')
  platform.work_unit().create_artifact(
      platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir')

  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
if __name__ == '__main__':
  # Required flags are declared only under the script guard, so importing
  # this module elsewhere does not impose the requirement.
  required_flags = ('config', 'workdir')
  flags.mark_flags_as_required(required_flags)
  app.run(main)