Commit

new experiment w/o causal mask
rasbt committed May 18, 2024
1 parent 57634f2 commit 5ef4edf
Showing 3 changed files with 30 additions and 11 deletions.
4 changes: 4 additions & 0 deletions ch06/02_bonus_additional-experiments/README.md
@@ -23,6 +23,7 @@ For example,
| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
| 13 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |


 
@@ -43,6 +44,7 @@ You can use the following code to reproduce the experiments:
- Row 10: `python additional-experiments.py --context_length "model_context_length"`
- Row 11: `python additional-experiments.py --no_padding --batch_size 1`
- Row 12: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
- Row 13: `python additional-experiments.py --disable_causal_mask`

I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes in case you don't have access to a GPU.

@@ -65,3 +67,5 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length results in significantly worse performance.

8. **Padding vs. no padding (Row 1 vs. 11 and 12)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 12, we additionally enable gradient accumulation with 8 steps to achieve the same effective batch size as in the other experiments, which helps reduce overfitting and slightly boosts the test set accuracy (see the gradient-accumulation sketch below).

9. **Disabling the causal attention mask (Row 1 vs. 13)**: The `--disable_causal_mask` option disables the causal attention mask used in the multi-head attention module, so that all tokens can attend to all other tokens. The model accuracy is slightly improved compared to the GPT model with the causal mask (see the attention-mask sketch below).
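To make the effect of `--disable_causal_mask` in item 9 concrete, here is a small single-head attention sketch (not the repository's `MultiHeadAttention` class). With the causal mask, attention weights above the diagonal are zero; without it, every token attends to every other token:

```python
import torch

torch.manual_seed(123)
num_tokens, d = 4, 8
queries = torch.randn(num_tokens, d)
keys = torch.randn(num_tokens, d)

attn_scores = queries @ keys.T / d**0.5

# Causal variant: future positions (above the diagonal) are masked out before the softmax
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
causal_weights = torch.softmax(attn_scores.masked_fill(mask, -torch.inf), dim=-1)

# Non-causal variant (what --disable_causal_mask enables): no masking at all
bidirectional_weights = torch.softmax(attn_scores, dim=-1)

print(causal_weights)         # upper-triangular entries are exactly 0
print(bidirectional_weights)  # every entry is nonzero
```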
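To illustrate the gradient-accumulation idea from item 8, here is a minimal, self-contained sketch. It is not the training loop from `additional-experiments.py`; the toy model, optimizer, and data below are made up purely for illustration:

```python
import torch

# Toy stand-ins for the classifier and the variable-length dataset (illustrative only)
model = torch.nn.Linear(10, 2)                      # placeholder classification head
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
data = [(torch.randn(1, 10), torch.tensor([0])) for _ in range(16)]

accumulation_steps = 8                              # emulate an effective batch size of 8
optimizer.zero_grad()
for step, (inputs, targets) in enumerate(data):
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    (loss / accumulation_steps).backward()          # scale so the accumulated gradient is an average
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                            # one parameter update per 8 micro-batches
        optimizer.zero_grad()
```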
11 changes: 10 additions & 1 deletion ch06/02_bonus_additional-experiments/additional-experiments.py
@@ -153,7 +153,7 @@ def instantiate_model(choose_model, load_weights):

     if not load_weights:
         torch.manual_seed(123)
-    model = GPTModel(BASE_CONFIG)
+    model = GPTModel(BASE_CONFIG, disable_causal_mask=args.disable_causal_mask)

     if load_weights:
         model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
@@ -386,6 +386,15 @@ def replace_linear_with_lora(model, rank, alpha):
         )
     )

+    parser.add_argument(
+        "--disable_causal_mask",
+        action='store_true',
+        default=False,
+        help=(
+            "Disables the causal attention mask."
+        )
+    )
+
     args = parser.parse_args()

     if args.trainable_token == "first":
26 changes: 16 additions & 10 deletions ch06/02_bonus_additional-experiments/previous_chapters.py
@@ -60,7 +60,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
 # Chapter 3
 #####################################
 class MultiHeadAttention(nn.Module):
-    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, disable_causal_mask=False):
         super().__init__()
         assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

@@ -73,7 +73,10 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
         self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
         self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)
-        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+
+        if not disable_causal_mask:
+            self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
+        self.disable_causal_mask = disable_causal_mask

     def forward(self, x):
         b, num_tokens, d_in = x.shape
@@ -96,11 +99,12 @@ def forward(self, x):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

-        # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        if not self.disable_causal_mask:
+            # Original mask truncated to the number of tokens and converted to boolean
+            mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

-        # Use the mask to fill attention scores
-        attn_scores.masked_fill_(mask_bool, -torch.inf)
+            # Use the mask to fill attention scores
+            attn_scores.masked_fill_(mask_bool, -torch.inf)

         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
         attn_weights = self.dropout(attn_weights)
@@ -157,15 +161,17 @@ def forward(self, x):


 class TransformerBlock(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, cfg, disable_causal_mask=False):
         super().__init__()
         self.att = MultiHeadAttention(
             d_in=cfg["emb_dim"],
             d_out=cfg["emb_dim"],
             context_length=cfg["context_length"],
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
-            qkv_bias=cfg["qkv_bias"])
+            qkv_bias=cfg["qkv_bias"],
+            disable_causal_mask=disable_causal_mask
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
@@ -190,14 +196,14 @@ def forward(self, x):


 class GPTModel(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, cfg, disable_causal_mask=False):
         super().__init__()
         self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
         self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
         self.drop_emb = nn.Dropout(cfg["drop_rate"])

         self.trf_blocks = nn.Sequential(
-            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
+            *[TransformerBlock(cfg, disable_causal_mask) for _ in range(cfg["n_layers"])])

         self.final_norm = LayerNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
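For reference, here is a rough usage sketch of how the new `disable_causal_mask` argument flows from `GPTModel` through `TransformerBlock` into `MultiHeadAttention` after this commit. It assumes the updated `previous_chapters.py` is importable; the miniature config below is made up for illustration and is much smaller than the GPT-2 settings used in the actual experiments:

```python
import torch
from previous_chapters import GPTModel

# Made-up miniature config (the real runs use the GPT-2 "small" settings
# defined in additional-experiments.py)
cfg = {
    "vocab_size": 100,
    "context_length": 16,
    "emb_dim": 32,
    "n_heads": 4,
    "n_layers": 2,
    "drop_rate": 0.0,
    "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(cfg, disable_causal_mask=True)  # flag is forwarded to every TransformerBlock

# When the causal mask is disabled, MultiHeadAttention never registers its 'mask' buffer
print(any("mask" in name for name, _ in model.named_buffers()))  # False

tokens = torch.randint(0, cfg["vocab_size"], (1, 8))
logits = model(tokens)                           # forward pass works the same either way
print(logits.shape)                              # torch.Size([1, 8, 100])
```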
