diff --git a/ch06/02_bonus_additional-experiments/README.md b/ch06/02_bonus_additional-experiments/README.md index 4dfec41d..3e8a2eb6 100644 --- a/ch06/02_bonus_additional-experiments/README.md +++ b/ch06/02_bonus_additional-experiments/README.md @@ -28,6 +28,7 @@ For example, | 15 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 | | 16 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 | | 17 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 | +| 18 | gpt2-small (124M) | pretrained | last + pooled embeddings | last_block | longest train ex. (120) | 97.79% | 99.33% | 96.33% | 0.32 min | A100 |   @@ -52,6 +53,7 @@ You can use the following code to reproduce the experiments: - Row 15: `python additional_experiments.py --no_padding --batch_size 1 --accumulation_steps 8` - Row 16: `python additional_experiments.py --disable_causal_mask` - Row 17: `python additional_experiments.py --ignore_index 50256` +- Row 18: `python additional_experiments.py --average_embeddings` I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes (for the default setting) in case you don't have access to a GPU. @@ -70,3 +72,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a 9. **Padding vs no padding (Row 1 vs. 14 and 15)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. 
In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy. 10. **Disabling the causal attention mask (Row 1 vs. 16)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask. 11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 17)**: Setting `--ignore_index 50256` excludes the `|endoftext|` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7. +12. **Averaging the embeddings over all tokens (Row 1 vs. 18)**: Setting `--average_embeddings` will average the embeddings over all tokens. If this option is not used (the default), only the output embeddings at the chosen token position (specified by `--trainable_token_pos`) are considered; for example, the embeddings of the last token. Enabling `--average_embeddings` will instead mean-pool the embeddings of all tokens along the sequence dimension; in that case, the position chosen by `--trainable_token_pos` (the last token by default) is not used. As we can see, this improves the performance from 95.00% to 96.33% with only a minimal increase in run time (0.28 min to 0.32 min) and might be worth considering in practice. 
\ No newline at end of file diff --git a/ch06/02_bonus_additional-experiments/additional_experiments.py b/ch06/02_bonus_additional-experiments/additional_experiments.py index 47dd2150..e3b1b177 100644 --- a/ch06/02_bonus_additional-experiments/additional_experiments.py +++ b/ch06/02_bonus_additional-experiments/additional_experiments.py @@ -181,15 +181,24 @@ def instantiate_model(choose_model, load_weights): def calc_loss_batch(input_batch, target_batch, model, device, - trainable_token_pos=-1, ignore_index=-100): + trainable_token_pos=-1, ignore_index=-100, average_embeddings=False): input_batch, target_batch = input_batch.to(device), target_batch.to(device) - logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token + + model_output = model(input_batch) + if average_embeddings: + # Average over the sequence dimension (dim=1) + logits = model_output.mean(dim=1) + else: + # Select embeddings at the specified token position + logits = model_output[:, trainable_token_pos, :] + loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index) return loss def calc_loss_loader(data_loader, model, device, - num_batches=None, trainable_token_pos=-1, ignore_index=-100): + num_batches=None, trainable_token_pos=-1, + ignore_index=-100, average_embeddings=False): total_loss = 0. 
if len(data_loader) == 0: return float("nan") @@ -203,7 +212,8 @@ def calc_loss_loader(data_loader, model, device, if i < num_batches: loss = calc_loss_batch( input_batch, target_batch, model, device, - trainable_token_pos=trainable_token_pos, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index, + average_embeddings=average_embeddings ) total_loss += loss.item() else: @@ -212,7 +222,8 @@ def calc_loss_loader(data_loader, model, device, @torch.no_grad() # Disable gradient tracking for efficiency -def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token_pos=-1): +def calc_accuracy_loader(data_loader, model, device, num_batches=None, + trainable_token_pos=-1, average_embeddings=False): model.eval() correct_predictions, num_examples = 0, 0 @@ -223,7 +234,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) - logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token + + model_output = model(input_batch) + if average_embeddings: + # Average over the sequence dimension (dim=1) + logits = model_output.mean(dim=1) + else: + # Select embeddings at the specified token position + logits = model_output[:, trainable_token_pos, :] + predicted_labels = torch.argmax(logits, dim=-1) num_examples += predicted_labels.shape[0] @@ -234,16 +253,19 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable def evaluate_model(model, train_loader, val_loader, device, - eval_iter, trainable_token_pos=-1, ignore_index=-100): + eval_iter, trainable_token_pos=-1, + ignore_index=-100, average_embeddings=False): model.eval() with torch.no_grad(): train_loss = calc_loss_loader( train_loader, model, device, num_batches=eval_iter, - trainable_token_pos=trainable_token_pos, 
ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index, + average_embeddings=average_embeddings ) val_loss = calc_loss_loader( val_loader, model, device, num_batches=eval_iter, - trainable_token_pos=trainable_token_pos, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index, + average_embeddings=average_embeddings ) model.train() return train_loss, val_loss @@ -251,7 +273,7 @@ def evaluate_model(model, train_loader, val_loader, device, def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1, - accumulation_steps=1, ignore_index=-100): + accumulation_steps=1, ignore_index=-100, average_embeddings=False): # Initialize lists to track losses and tokens seen train_losses, val_losses, train_accs, val_accs = [], [], [], [] examples_seen, global_step = 0, -1 @@ -263,7 +285,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, for batch_idx, (input_batch, target_batch) in enumerate(train_loader): loss = calc_loss_batch( input_batch, target_batch, model, device, - trainable_token_pos=trainable_token_pos, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index, + average_embeddings=average_embeddings ) # Use gradient accumulation if accumulation_steps > 1 @@ -286,7 +309,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, if global_step % eval_freq == 0: train_loss, val_loss = evaluate_model( model, train_loader, val_loader, device, eval_iter, - trainable_token_pos=trainable_token_pos, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index, + average_embeddings=average_embeddings ) train_losses.append(train_loss) val_losses.append(val_loss) @@ -297,8 +321,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, break # New: Calculate 
accuracy after each epoch - train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos) - val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos) + train_accuracy = calc_accuracy_loader( + train_loader, model, device, num_batches=eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) + val_accuracy = calc_accuracy_loader( + val_loader, model, device, num_batches=eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") print(f"Validation accuracy: {val_accuracy*100:.2f}%") train_accs.append(train_accuracy) @@ -359,13 +389,22 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False): "Which token position to train. Options: 'first', 'last'." ) ) + parser.add_argument( + "--average_embeddings", + action='store_true', + default=False, + help=( + "Average the output embeddings from all tokens instead of using" + " only the embedding at the token position specified by `--trainable_token_pos`." + ) + ) parser.add_argument( "--context_length", type=str, default="longest_training_example", help=( "The context length of the data inputs." - "Options: 'longest_training_example', 'model_context_length' or integer value." + " Options: 'longest_training_example', 'model_context_length' or integer value." ) ) parser.add_argument( @@ -409,7 +448,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False): "The batch size used for training." ) ) - parser.add_argument( "--accumulation_steps", type=int, @@ -422,7 +460,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False): " the latter setting uses more iterations." 
) ) - parser.add_argument( "--disable_causal_mask", action='store_true', @@ -431,7 +468,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False): "Disables the causal attention mask." ) ) - parser.add_argument( "--ignore_index", type=int, @@ -589,7 +625,7 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False): model, train_loader, val_loader, optimizer, device, num_epochs=args.num_epochs, eval_freq=50, eval_iter=5, max_steps=None, trainable_token_pos=args.trainable_token_pos, - accumulation_steps=args.accumulation_steps + accumulation_steps=args.accumulation_steps, average_embeddings=args.average_embeddings ) end_time = time.time() @@ -600,9 +636,18 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False): # Evaluate model ############################### - train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token_pos=args.trainable_token_pos) - val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token_pos=args.trainable_token_pos) - test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token_pos=args.trainable_token_pos) + train_accuracy = calc_accuracy_loader( + train_loader, model, device, + trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings + ) + val_accuracy = calc_accuracy_loader( + val_loader, model, device, + trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings + ) + test_accuracy = calc_accuracy_loader( + test_loader, model, device, + trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings + ) print(f"Training accuracy: {train_accuracy*100:.2f}%") print(f"Validation accuracy: {val_accuracy*100:.2f}%") diff --git a/ch06/03_bonus_imdb-classification/train_gpt.py b/ch06/03_bonus_imdb-classification/train_gpt.py index ca092ea0..e9df13ad 100644 --- a/ch06/03_bonus_imdb-classification/train_gpt.py +++ 
b/ch06/03_bonus_imdb-classification/train_gpt.py @@ -81,14 +81,25 @@ def instantiate_model(choose_model, load_weights): return model -def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1): +def calc_loss_batch(input_batch, target_batch, model, device, + trainable_token_pos=-1, average_embeddings=False): input_batch, target_batch = input_batch.to(device), target_batch.to(device) - logits = model(input_batch)[:, trainable_token, :] # Logits of last output token + + model_output = model(input_batch) + if average_embeddings: + # Average over the sequence dimension (dim=1) + logits = model_output.mean(dim=1) + else: + # Select embeddings at the specified token position + logits = model_output[:, trainable_token_pos, :] + loss = torch.nn.functional.cross_entropy(logits, target_batch) return loss -def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1): +def calc_loss_loader(data_loader, model, device, + num_batches=None, trainable_token_pos=-1, + average_embeddings=False): total_loss = 0. 
if len(data_loader) == 0: return float("nan") @@ -100,7 +111,10 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok num_batches = min(num_batches, len(data_loader)) for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: - loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token) + loss = calc_loss_batch( + input_batch, target_batch, model, device, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) total_loss += loss.item() else: break @@ -108,7 +122,9 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok @torch.no_grad() # Disable gradient tracking for efficiency -def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1): +def calc_accuracy_loader(data_loader, model, device, + num_batches=None, trainable_token_pos=-1, + average_embeddings=False): model.eval() correct_predictions, num_examples = 0, 0 @@ -119,7 +135,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) - logits = model(input_batch)[:, trainable_token, :] # Logits of last output token + + model_output = model(input_batch) + if average_embeddings: + # Average over the sequence dimension (dim=1) + logits = model_output.mean(dim=1) + else: + # Select embeddings at the specified token position + logits = model_output[:, trainable_token_pos, :] + predicted_labels = torch.argmax(logits, dim=-1) num_examples += predicted_labels.shape[0] @@ -129,17 +153,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable return correct_predictions / num_examples -def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1): +def evaluate_model(model, train_loader, val_loader, 
device, eval_iter, + trainable_token_pos=-1, average_embeddings=False): model.eval() with torch.no_grad(): - train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) - val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + train_loss = calc_loss_loader( + train_loader, model, device, num_batches=eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) + val_loss = calc_loss_loader( + val_loader, model, device, num_batches=eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) model.train() return train_loss, val_loss def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, - eval_freq, eval_iter, max_steps=None, trainable_token=-1): + eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1, + average_embeddings=False): # Initialize lists to track losses and tokens seen train_losses, val_losses, train_accs, val_accs = [], [], [], [] examples_seen, global_step = 0, -1 @@ -150,7 +182,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, for input_batch, target_batch in train_loader: optimizer.zero_grad() # Reset loss gradients from previous batch iteration - loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token) + loss = calc_loss_batch(input_batch, target_batch, model, device, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings) loss.backward() # Calculate loss gradients optimizer.step() # Update model weights using loss gradients examples_seen += input_batch.shape[0] # New: track examples instead of tokens @@ -159,7 +192,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, # Optional evaluation step if global_step % eval_freq == 0: train_loss, val_loss = evaluate_model( - model, train_loader, 
val_loader, device, eval_iter, trainable_token=trainable_token) + model, train_loader, val_loader, device, eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) train_losses.append(train_loss) val_losses.append(val_loss) print(f"Ep {epoch+1} (Step {global_step:06d}): " @@ -169,8 +204,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, break # New: Calculate accuracy after each epoch - train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) - val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + train_accuracy = calc_accuracy_loader( + train_loader, model, device, num_batches=eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) + val_accuracy = calc_accuracy_loader( + val_loader, model, device, num_batches=eval_iter, + trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings + ) print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") print(f"Validation accuracy: {val_accuracy*100:.2f}%") train_accs.append(train_accuracy) @@ -211,13 +252,22 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, ) ) parser.add_argument( - "--trainable_token", + "--trainable_token_pos", type=str, default="last", help=( "Which token to train. Options: 'first', 'last'." ) ) + parser.add_argument( + "--average_embeddings", + action='store_true', + default=False, + help=( + "Average the output embeddings from all tokens instead of using" + " only the embedding at the token position specified by `--trainable_token_pos`." 
+ ) + ) parser.add_argument( "--context_length", type=str, @@ -245,12 +295,12 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, ) args = parser.parse_args() - if args.trainable_token == "first": - args.trainable_token = 0 - elif args.trainable_token == "last": - args.trainable_token = -1 + if args.trainable_token_pos == "first": + args.trainable_token_pos = 0 + elif args.trainable_token_pos == "last": + args.trainable_token_pos = -1 else: - raise ValueError("Invalid --trainable_token argument") + raise ValueError("Invalid --trainable_token_pos argument") ############################### # Load model @@ -358,7 +408,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( model, train_loader, val_loader, optimizer, device, num_epochs=args.num_epochs, eval_freq=50, eval_iter=20, - max_steps=None, trainable_token=args.trainable_token + max_steps=None, trainable_token_pos=args.trainable_token_pos, + average_embeddings=args.average_embeddings ) end_time = time.time() @@ -371,9 +422,18 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, print("\nEvaluating on the full datasets ...\n") - train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token) - val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token) - test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token) + train_accuracy = calc_accuracy_loader( + train_loader, model, device, + trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings + ) + val_accuracy = calc_accuracy_loader( + val_loader, model, device, + trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings + ) + test_accuracy = calc_accuracy_loader( + test_loader, model, device, + 
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings + ) print(f"Training accuracy: {train_accuracy*100:.2f}%") print(f"Validation accuracy: {val_accuracy*100:.2f}%")