Add mean pooling experiment to classifier bonus experiments (#406)
* Add mean pooling experiment to classifier bonus experiments

* formatting

* add average embeddings option

* pep8
rasbt authored Oct 20, 2024
1 parent 467197b commit 3896986
Showing 3 changed files with 154 additions and 46 deletions.
3 changes: 3 additions & 0 deletions ch06/02_bonus_additional-experiments/README.md
@@ -28,6 +28,7 @@ For example,
| 15 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
| 16 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
| 17 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
| 18 | gpt2-small (124M) | pretrained | last + pooled embeddings | last_block | longest train ex. (120) | 97.79% | 99.33% | 96.33% | 0.32 min | A100 |

 

@@ -52,6 +53,7 @@ You can use the following code to reproduce the experiments:
- Row 15: `python additional_experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
- Row 16: `python additional_experiments.py --disable_causal_mask`
- Row 17: `python additional_experiments.py --ignore_index 50256`
- Row 18: `python additional_experiments.py --average_embeddings`

I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes (for the default setting) in case you don't have access to a GPU.

@@ -70,3 +72,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
9. **Padding vs no padding (Row 1 vs. 14 and 15)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to reach the same effective batch size as in the other experiments, which helps reduce overfitting and slightly boosts the test set accuracy (a minimal gradient-accumulation sketch follows this list).
10. **Disabling the causal attention mask (Row 1 vs. 16)**: Disables the causal attention mask used in the multi-head attention module, so every token can attend to every other token. The model accuracy is slightly improved compared to the GPT model with the causal mask (see the attention-mask sketch after this list).
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 17)**: Setting `--ignore_index 50256` excludes the `<|endoftext|>` padding tokens from PyTorch's `cross_entropy` loss function. In this case, it has no effect because we replaced the output layer so that the target labels are either 0 or 1 for the binary classification example, meaning the padding token ID never appears among the targets (see the `ignore_index` sketch after this list). However, this setting is useful when instruction-finetuning models in chapter 7.
12. **Averaging the embeddings over all tokens (Row 1 vs. 18)**: Setting `--average_embeddings` mean-pools the output embeddings over all tokens before the classification loss and accuracy are computed. If this option is not used (the default), only the output embedding at the chosen token position (specified by `--trainable_token_pos`, the last token by default) is used. As the results show, mean pooling improves the test accuracy from 95.00% to 96.33% with only a minimal increase in runtime (0.28 min to 0.32 min) and might be worth considering in practice (see the mean-pooling sketch after this list).
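
A minimal sketch of the gradient-accumulation idea behind row 15 (toy model and data, not the repository code; the variant shown scales the loss by the number of accumulation steps, which the actual script may handle differently):

```python
import torch

model = torch.nn.Linear(4, 2)                     # toy classifier standing in for the GPT model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 8

optimizer.zero_grad()
for step in range(32):                            # pretend each iteration is a batch of size 1
    x, y = torch.randn(1, 4), torch.randint(0, 2, (1,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    (loss / accumulation_steps).backward()        # gradients accumulate across the 8 small batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                          # one update per 8 accumulated batches ~ batch size 8
        optimizer.zero_grad()
```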
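
A small sketch of what disabling the causal mask (row 16) changes, using illustrative attention scores rather than the repository's attention module:

```python
import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)            # raw attention scores for one head

causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
attn_causal = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)
attn_full = torch.softmax(scores, dim=-1)         # --disable_causal_mask: every token attends to every token

print(attn_causal[0])  # first row: the first token can only attend to itself
print(attn_full[0])    # without the mask, it attends to all five tokens
```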
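
A minimal sketch of how `ignore_index` behaves in PyTorch's `cross_entropy` (illustrative values, not the repository code): target positions equal to the ignored index are dropped from the loss, which is why the flag only matters when padding token IDs actually appear among the targets:

```python
import torch

logits = torch.randn(4, 2)                        # 4 positions, 2 classes
targets = torch.tensor([0, 1, 50256, 50256])      # pretend the last two targets are padding token IDs

loss_ignored = torch.nn.functional.cross_entropy(logits, targets, ignore_index=50256)
loss_first_two = torch.nn.functional.cross_entropy(logits[:2], targets[:2])
print(torch.allclose(loss_ignored, loss_first_two))  # True: the padding positions are skipped
```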
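
And a minimal sketch contrasting the default last-token logits with the mean-pooled logits enabled by `--average_embeddings` (illustrative shapes, not the repository code):

```python
import torch

batch_size, seq_len, num_classes = 4, 6, 2
model_output = torch.randn(batch_size, seq_len, num_classes)  # [batch, tokens, classes], as returned by the classifier

logits_last = model_output[:, -1, :]       # default: logits of the last token only
logits_mean = model_output.mean(dim=1)     # --average_embeddings: average over the sequence dimension

targets = torch.randint(0, num_classes, (batch_size,))
loss_last = torch.nn.functional.cross_entropy(logits_last, targets)
loss_mean = torch.nn.functional.cross_entropy(logits_mean, targets)
print(loss_last.item(), loss_mean.item())
```
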
89 changes: 67 additions & 22 deletions ch06/02_bonus_additional-experiments/additional_experiments.py
@@ -181,15 +181,24 @@ def instantiate_model(choose_model, load_weights):


def calc_loss_batch(input_batch, target_batch, model, device,
trainable_token_pos=-1, ignore_index=-100):
trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token

model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]

loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
return loss


def calc_loss_loader(data_loader, model, device,
num_batches=None, trainable_token_pos=-1, ignore_index=-100):
num_batches=None, trainable_token_pos=-1,
ignore_index=-100, average_embeddings=False):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
@@ -203,7 +212,8 @@ def calc_loss_loader(data_loader, model, device,
if i < num_batches:
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
total_loss += loss.item()
else:
@@ -212,7 +222,8 @@ def calc_loss_loader(data_loader, model, device,


@torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token_pos=-1):
def calc_accuracy_loader(data_loader, model, device, num_batches=None,
trainable_token_pos=-1, average_embeddings=False):
model.eval()
correct_predictions, num_examples = 0, 0

@@ -223,7 +234,15 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token

model_output = model(input_batch)
if average_embeddings:
# Average over the sequence dimension (dim=1)
logits = model_output.mean(dim=1)
else:
# Select embeddings at the specified token position
logits = model_output[:, trainable_token_pos, :]

predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
@@ -234,24 +253,27 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable


def evaluate_model(model, train_loader, val_loader, device,
eval_iter, trainable_token_pos=-1, ignore_index=-100):
eval_iter, trainable_token_pos=-1,
ignore_index=-100, average_embeddings=False):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(
train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
val_loss = calc_loss_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
model.train()
return train_loss, val_loss


def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, max_steps=None, trainable_token_pos=-1,
accumulation_steps=1, ignore_index=-100):
accumulation_steps=1, ignore_index=-100, average_embeddings=False):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
@@ -263,7 +285,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)

# Use gradient accumulation if accumulation_steps > 1
@@ -286,7 +309,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter,
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
trainable_token_pos=trainable_token_pos, ignore_index=ignore_index,
average_embeddings=average_embeddings
)
train_losses.append(train_loss)
val_losses.append(val_loss)
@@ -297,8 +321,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
break

# New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
train_accuracy = calc_accuracy_loader(
train_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device, num_batches=eval_iter,
trainable_token_pos=trainable_token_pos, average_embeddings=average_embeddings
)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
@@ -359,13 +389,22 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
"Which token position to train. Options: 'first', 'last'."
)
)
parser.add_argument(
"--average_embeddings",
action='store_true',
default=False,
help=(
"Average the output embeddings from all tokens instead of using"
" only the embedding at the token position specified by `--trainable_token_pos`."
)
)
parser.add_argument(
"--context_length",
type=str,
default="longest_training_example",
help=(
"The context length of the data inputs."
"Options: 'longest_training_example', 'model_context_length' or integer value."
" Options: 'longest_training_example', 'model_context_length' or integer value."
)
)
parser.add_argument(
@@ -409,7 +448,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
"The batch size used for training."
)
)

parser.add_argument(
"--accumulation_steps",
type=int,
@@ -422,7 +460,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
" the latter setting uses more iterations."
)
)

parser.add_argument(
"--disable_causal_mask",
action='store_true',
@@ -431,7 +468,6 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
"Disables the causal attention mask."
)
)

parser.add_argument(
"--ignore_index",
type=int,
@@ -589,7 +625,7 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
model, train_loader, val_loader, optimizer, device,
num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
max_steps=None, trainable_token_pos=args.trainable_token_pos,
accumulation_steps=args.accumulation_steps
accumulation_steps=args.accumulation_steps, average_embeddings=args.average_embeddings
)

end_time = time.time()
@@ -600,9 +636,18 @@ def replace_linear_with_lora(model, rank, alpha, alternative=False):
# Evaluate model
###############################

train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token_pos=args.trainable_token_pos)
val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token_pos=args.trainable_token_pos)
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token_pos=args.trainable_token_pos)
train_accuracy = calc_accuracy_loader(
train_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
val_accuracy = calc_accuracy_loader(
val_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)
test_accuracy = calc_accuracy_loader(
test_loader, model, device,
trainable_token_pos=args.trainable_token_pos, average_embeddings=args.average_embeddings
)

print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")