-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
367 lines (289 loc) · 15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# import locale
# locale.getpreferredencoding = lambda *args: "UTF-8"
import os
import argparse
import logging
import math
import time
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.utils.checkpoint
from collections import defaultdict
import transformers
from datasets import load_dataset, load_from_disk
from transformers import set_seed, AutoTokenizer, FlaxBertModel
from huggingface_hub import create_repo, upload_folder
from diffusers.utils import check_min_version, is_wandb_available
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils
from flax.core.frozen_dict import freeze, unfreeze
from flax.training import train_state, checkpoints
from flax.training.common_utils import shard
from flax.training import orbax_utils
import orbax.checkpoint
import diffusion_model as dm
import model_utils as u
# wandb is an optional dependency: import it only when diffusers reports it as installed.
if is_wandb_available():
    import wandb

# NOTE(review): presumably aborts when the installed diffusers is older than this version — confirm.
check_min_version("0.16.0.dev0")

# Printed at import time so the visible accelerator count shows up at the top of every run log.
print(f"Device count : {jax.device_count()}")

# Module-level logger; its level is adjusted per process inside main().
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before.

    Returns:
        The populated ``argparse.Namespace``. It is also printed to stdout.
    """
    parser = argparse.ArgumentParser()

    # Run / optimisation settings.
    parser.add_argument('--seed', type=int, default=0)
    # set bigger, think batch_size * accum_steps
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--timesteps', type=int, default=2000)
    parser.add_argument('--prefix', type=str, default='test')
    parser.add_argument('--padding_mode', type=str, default='normal')
    parser.add_argument('--data_path', type=str, default='data/poems/poems.txt')
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--latent_dim', type=int, default=32)
    parser.add_argument('--seq_len', type=int, default=64)
    parser.add_argument('--rewrite_vocab', action='store_true')
    parser.add_argument('--use_pretrained', action='store_true')

    # Reporting / hub settings.
    parser.add_argument(
        '--push_to_hub',
        action='store_true',
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        '--report_to',
        type=str,
        default="wandb",
        help=('The integration to report the results and logs to. '
              'Currently only supported platforms are `"wandb"`'),
    )
    parser.add_argument('--output_dir', type=str, default='test')
    # how do we get one?
    parser.add_argument('--hub_token', type=str, default='test')
    parser.add_argument('--hub_model_id', type=str, default='test')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)

    # Profiling, logging and checkpoint cadence.
    parser.add_argument(
        '--profile_memory',
        action='store_true',
        help=("Whether to dump an initial (before training loop) and a final "
              "(at program end) memory profile."),
    )
    parser.add_argument(
        '--profile_steps',
        type=int,
        default=2,
        help="How many training steps to profile in the beginning.",
    )
    parser.add_argument(
        '--logging_steps',
        type=int,
        default=300,
        help="log training metric every X steps to `--report_to`",
    )
    # The help text used to be a copy of the --logging_steps one.
    parser.add_argument(
        '--checkpointing_steps',
        type=int,
        default=300,
        help="save a checkpoint every X steps",
    )

    args = parser.parse_args(argv)
    print('args', args)
    return args
def _build_train_dataloader(args):
    """Tokenize the training text and wrap it in a shuffled DataLoader.

    With ``args.use_pretrained`` the ``bert-base-cased`` tokenizer is applied
    to the text file at ``args.data_path``; otherwise the project tokenizer
    and vocabulary helpers from ``model_utils`` are used.

    Returns:
        ``(train_dataset, train_dataloader, vocab_size)``. Every batch the
        loader yields is a dict with a single ``"input_ids"`` numpy array.
    """
    if args.use_pretrained:
        tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
        vocab_size = tokenizer.vocab_size
        raw_dataset = load_dataset("text", data_files=args.data_path)
        tokenized = raw_dataset.map(
            lambda sample: tokenizer(
                sample['text'],
                padding='max_length',
                truncation=True,
                max_length=args.seq_len,
            ),
            batched=True,
        )
        train_dataset = tokenized['train']
    else:
        tokenizer = u.get_tokenizer()
        vocab_dict = u.make_vocab(
            tokenizer=tokenizer,
            data_path=args.data_path,
            rewrite=args.rewrite_vocab,
        )
        vocab_size = len(vocab_dict)
        train_dataset = u.make_dataset(
            args.data_path,
            vocab_dict,
            padding_mode=args.padding_mode,
            seq_length=args.seq_len,
        )

    def collate_fn(examples):
        # Only the token ids are fed to the model; stack them into one numpy array.
        input_ids = torch.stack(
            [torch.tensor(example["input_ids"]) for example in examples]
        ).numpy()
        return {"input_ids": input_ids}

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,  # false if streaming dataset
        collate_fn=collate_fn,
        batch_size=args.batch_size,
        drop_last=True,
    )
    return train_dataset, train_dataloader, vocab_size


def _make_train_step(diff_lm, grad_steps):
    """Build the jitted function that performs one optimizer update.

    Args:
        diff_lm: The model whose ``apply`` returns a dict containing ``'loss'``.
        grad_steps: Number of micro-batches each incoming batch is split into.
            Gradients and the loss dict are averaged over them before the
            update is applied.

    Returns:
        ``train_step(state, batch, rng) -> (new_state, new_rng, loss, loss_dict)``.
    """

    def compute_loss(params, batch, rng):
        rng, rng_dropout = jax.random.split(rng)
        losses_dict = diff_lm.apply(params, batch, rng, rngs={'dropout': rng_dropout})
        return losses_dict['loss'], losses_dict

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

    @jax.jit
    def train_step(state, batch, rng):
        def loss_and_grad(minibatch, step_rng):
            sample_rng, next_rng = jax.random.split(step_rng, 2)
            (_, loss_dict), grad = grad_fn(state.params, minibatch, sample_rng)
            return loss_dict, grad, next_rng

        if grad_steps == 1:
            loss_dict, grads, new_rng = loss_and_grad(batch, rng)
        else:
            # Split the leading batch axis into (grad_steps, micro_batch).
            micro_batches = jax.tree_util.tree_map(
                lambda x: x.reshape((grad_steps, x.shape[0] // grad_steps) + x.shape[1:]),
                batch,
            )

            def get_minibatch(grad_idx):
                return jax.tree_util.tree_map(
                    lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
                    micro_batches,
                )

            # The first micro-batch is evaluated outside the loop so the carry
            # starts with the real structure of the loss dict and gradients.
            # The previous code never accumulated the loss dict and then
            # returned an unbound ``loss_dict``.
            init_carry = loss_and_grad(get_minibatch(0), rng)

            def cumul_grad_step(grad_idx, carry):
                cumul_dict, cumul_grad, step_rng = carry
                loss_dict, grad, next_rng = loss_and_grad(get_minibatch(grad_idx), step_rng)
                cumul_dict, cumul_grad = jax.tree_util.tree_map(
                    jnp.add, (cumul_dict, cumul_grad), (loss_dict, grad)
                )
                return cumul_dict, cumul_grad, next_rng

            cumul_dict, cumul_grad, new_rng = jax.lax.fori_loop(
                1,           # micro-batch 0 is already in the carry
                grad_steps,  # exclusive upper bound
                cumul_grad_step,
                init_carry,
            )
            loss_dict, grads = jax.tree_util.tree_map(
                lambda x: x / grad_steps, (cumul_dict, cumul_grad)
            )

        new_state = state.apply_gradients(grads=grads)
        return new_state, new_rng, loss_dict['loss'], loss_dict

    return train_step


def main():
    """Train the diffusion language model end to end.

    Parses the CLI arguments, builds the dataloader, initialises the model
    (optionally transplanting ``bert-base-cased`` weights), then runs the
    epoch loop. On process 0 it logs averaged metrics to wandb when
    ``--report_to wandb`` is set and saves checkpoints.

    Raises:
        ValueError: If ``--gradient_accumulation_steps`` is smaller than 1 or
            does not divide ``--batch_size``, or if the dataloader yields no
            batch.
        ImportError: If wandb reporting is requested but wandb is not installed.
    """
    args = parse_args()

    grad_steps = args.gradient_accumulation_steps
    if grad_steps < 1:
        raise ValueError("--gradient_accumulation_steps must be at least 1")
    if args.batch_size % grad_steps != 0:
        raise ValueError(
            "--batch_size must be divisible by --gradient_accumulation_steps "
            f"(got {args.batch_size} and {grad_steps})"
        )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)

    use_wandb = jax.process_index() == 0 and args.report_to == "wandb"
    if use_wandb:
        if not is_wandb_available():
            raise ImportError(
                "--report_to wandb was requested but wandb is not installed; "
                "install it or pass a different --report_to value."
            )
        wandb.init(
            entity="diff-lm",
            project="diff-lm",
            job_type="train",
            config=args,
        )

    train_dataset, train_dataloader, vocab_size = _build_train_dataloader(args)

    # initialize
    if args.seed is not None:
        set_seed(args.seed)
    rng = jax.random.PRNGKey(args.seed)
    rng, rng_params = jax.random.split(rng)
    rng, rng_dropout = jax.random.split(rng)

    if args.use_pretrained:
        pretrain_model = FlaxBertModel.from_pretrained('bert-base-cased')
        # The latent width has to match the pretrained encoder's hidden size.
        args.latent_dim = pretrain_model.config.hidden_size

    diff_lm = dm.DiffusionLM(
        timesteps=args.timesteps,
        latent_dim=args.latent_dim,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        vocab_size=vocab_size,
        use_pretrained=args.use_pretrained,
        train=True,
    )

    # One real batch is needed as the example input for parameter initialisation.
    try:
        sample_batch = next(iter(train_dataloader))
    except StopIteration:
        raise ValueError(
            "The training dataloader is empty: the dataset holds fewer examples "
            "than --batch_size."
        ) from None
    diff_lm_params = diff_lm.init(
        {'params': rng, 'dropout': rng_dropout}, sample_batch, rng_params
    )

    if args.use_pretrained:
        # init some weights
        diff_lm_params = unfreeze(diff_lm_params)
        bert_params = pretrain_model.params
        diff_lm_params['params']['transformer']['input_transformer'] = bert_params['encoder']
        diff_lm_params['params']['embedder']['embedding'] = (
            bert_params['embeddings']['word_embeddings']['embedding']
        )
        diff_lm_params['params']['transformer']['position_embeddings']['embedding'] = (
            bert_params['embeddings']['position_embeddings']['embedding']
        )
        diff_lm_params = freeze(diff_lm_params)
        del bert_params
        del pretrain_model

    # prep for training
    tx = optax.adamw(learning_rate=args.learning_rate, b1=0.9, b2=0.999, eps=1e-6)
    state = train_state.TrainState.create(
        apply_fn=diff_lm.__call__, params=diff_lm_params, tx=tx
    )
    train_rng, _ = jax.random.split(rng)

    train_step = _make_train_step(diff_lm, grad_steps)

    # Accumulation happens inside train_step, so every dataloader batch is one
    # optimizer update; dividing by the accumulation factor ended training early.
    dataset_length = len(train_dataloader)
    num_update_steps_per_epoch = dataset_length
    total_train_batch_size = args.batch_size
    max_train_steps = args.epochs * num_update_steps_per_epoch

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.epochs}")
    logger.info(f" Instantaneous batch size per device = {total_train_batch_size}")
    logger.info(f" Train batch per step = {args.batch_size // args.gradient_accumulation_steps}")
    logger.info(f" Otimization steps per epoch = {dataset_length}")

    if use_wandb:
        wandb.define_metric("*", step_metric="train/step")
        wandb.define_metric("train/step", step_metric="walltime")
        wandb.config.update(
            {
                "num_train_examples": len(train_dataset),
                "total_train_batch_size": args.batch_size,
                "total_optimization_step": max_train_steps,
                "num_devices": jax.device_count(),
                "diffusion_lm_params": sum(
                    np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)
                ),
            }
        )

    global_step = step0 = 0
    epochs = tqdm(
        range(args.epochs),
        desc="Epoch ... ",
        position=0,
        disable=jax.process_index() > 0,
    )
    t00 = t0 = time.monotonic()

    # Checkpoints go into the wandb run directory when a run is active;
    # otherwise fall back to --output_dir so the script does not crash on
    # ``wandb.run`` being unavailable.
    if use_wandb and wandb.run is not None:
        checkpoint_dir = wandb.run.dir
    else:
        checkpoint_dir = str(Path(args.output_dir).resolve())
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
    checkpoint_manager = orbax.checkpoint.CheckpointManager(
        checkpoint_dir, orbax_checkpointer, options
    )
    save_args = orbax_utils.save_args_from_target(state)

    metric_keys = ('mse', 'tT_loss', 'decoder_nll')
    steps_per_epoch = len(train_dataset) // total_train_batch_size

    for epoch in epochs:
        train_metrics = []
        train_metrics_dict = defaultdict(list)

        train_step_progress_bar = tqdm(
            total=steps_per_epoch,
            desc="Training...",
            position=1,
            leave=False,
            disable=jax.process_index() > 0,
        )

        for batch in train_dataloader:
            state, train_rng, loss, loss_dict = train_step(state, batch, train_rng)
            train_metrics.append(loss)
            train_step_progress_bar.update(1)
            for key in metric_keys:
                train_metrics_dict[key].append(loss_dict[key])

            global_step += 1
            if global_step >= max_train_steps:
                break

            if global_step % args.logging_steps == 0 and jax.process_index() == 0:
                if use_wandb:
                    # Average everything collected since the last log call.
                    mean_loss = jnp.array(train_metrics).mean()
                    mean_metrics = {
                        key: jnp.array(train_metrics_dict[key]).mean()
                        for key in metric_keys
                    }
                    wandb.log(
                        {
                            "walltime": time.monotonic() - t00,
                            "train/step": global_step,
                            "train/epoch": global_step / dataset_length,
                            "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
                            "train/loss": mean_loss,
                            **{f"train/{key}": mean_metrics[key] for key in metric_keys},
                        }
                    )
                t0, step0 = time.monotonic(), global_step
                train_metrics = []
                train_metrics_dict = defaultdict(list)

            if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
                checkpoint_manager.save(global_step, state, save_kwargs={"save_args": save_args})
                logger.info(f'Saved checkpoint at step {global_step}')

        train_step_progress_bar.close()
        epochs.write(f"Epoch... ({epoch + 1}/{args.epochs} | Loss: {loss})")

    if jax.process_index() == 0:
        checkpoint_manager.save(global_step, state, save_kwargs={"save_args": save_args})
        logger.info(f'Saved final checkpoint at step {global_step}')
# Start training only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()