Commit

Include mathematical breakdown for exercise solution 4.1 (#483)
rasbt authored Jan 15, 2025
1 parent b3150ee commit 126adb7
Showing 1 changed file with 54 additions and 2 deletions.
56 changes: 54 additions & 2 deletions ch04/01_main-chapter-code/exercise-solutions.ipynb
@@ -62,7 +62,33 @@
"execution_count": 2,
"id": "2751b0e5-ffd3-4be2-8db3-e20dd4d61d69",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=False)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=False)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=False)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_shortcut): Dropout(p=0.1, inplace=False)\n",
")\n"
]
}
],
"source": [
"from gpt import TransformerBlock\n",
"\n",
@@ -76,7 +102,8 @@
" \"qkv_bias\": False\n",
"}\n",
"\n",
"block = TransformerBlock(GPT_CONFIG_124M)"
"block = TransformerBlock(GPT_CONFIG_124M)\n",
"print(block)"
]
},
{
@@ -126,6 +153,31 @@
"- Optionally multiply by 12 to capture all transformer blocks in the 124M GPT model"
]
},
{
"cell_type": "markdown",
"id": "597e9251-e0a9-4972-8df6-f280f35939f9",
"metadata": {},
"source": [
"**Bonus: Mathematical breakdown**\n",
"\n",
"- For those interested in how these parameter counts are calculated mathematically, you can find the breakdown below (assuming `emb_dim=768`):\n",
"\n",
"\n",
"Feed forward module:\n",
"\n",
"- 1st `Linear` layer: 768 inputs × 4×768 outputs + 4×768 bias units = 2,362,368\n",
"- 2nd `Linear` layer: 4×768 inputs × 768 outputs + 768 bias units = 2,360,064\n",
"- Total: 1st `Linear` layer + 2nd `Linear` layer = 2,362,368 + 2,360,064 = 4,722,432\n",
"\n",
"Attention module:\n",
"\n",
"- `W_query`: 768 inputs × 768 outputs = 589,824 \n",
"- `W_key`: 768 inputs × 768 outputs = 589,824\n",
"- `W_value`: 768 inputs × 768 outputs = 589,824 \n",
"- `out_proj`: 768 inputs × 768 outputs + 768 bias units = 590,592\n",
"- Total: `W_query` + `W_key` + `W_value` + `out_proj` = 3×589,824 + 590,592 = 2,360,064 "
]
},
{
"cell_type": "markdown",
"id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d",
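
The parameter arithmetic in the new markdown cell can be double-checked with a few lines of plain Python. This is a standalone sketch that does not depend on the book's `gpt` module; it only assumes `emb_dim=768` and `qkv_bias=False` from the `GPT_CONFIG_124M` config shown in the diff:

```python
emb_dim = 768  # GPT_CONFIG_124M embedding dimension

# Feed-forward module: Linear(768 -> 3072) and Linear(3072 -> 768), both with bias
ff_layer1 = emb_dim * (4 * emb_dim) + 4 * emb_dim   # weights + biases = 2,362,368
ff_layer2 = (4 * emb_dim) * emb_dim + emb_dim       # weights + biases = 2,360,064
ff_total = ff_layer1 + ff_layer2

# Attention module: W_query/W_key/W_value without bias (qkv_bias=False),
# out_proj with bias
qkv = 3 * emb_dim * emb_dim                         # 3 x 589,824 weight parameters
out_proj = emb_dim * emb_dim + emb_dim              # 590,592 (weights + bias)
att_total = qkv + out_proj

print(f"feed forward: {ff_total:,}")   # feed forward: 4,722,432
print(f"attention:    {att_total:,}")  # attention:    2,360,064
```

Both totals match the breakdown in the added markdown cell; multiplying their sum (plus the LayerNorm parameters) by 12 gives the transformer-block contribution to the 124M-parameter model, as the exercise notes.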
