Drop interleave placement in QKV matrix (#1013)
Andrei-Aksionov authored Dec 26, 2024
1 parent db308ba commit fabf765
Showing 14 changed files with 725 additions and 598 deletions.
2 changes: 1 addition & 1 deletion litgpt/adapter.py
@@ -151,7 +151,7 @@ def scaled_dot_product_attention(
ak, av = self.adapter_kv_cache
else:
prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd)
- aqkv = self.attn(prefix)
+ aqkv = self.qkv(prefix)
q_per_kv = self.config.n_head // self.config.n_query_groups
aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size)
aqkv = aqkv.permute(0, 2, 3, 1, 4)
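The fused projection keeps query, key, and value weights in one matrix; with the interleaved placement dropped, all query features come first, followed by all key and then all value features. A minimal sketch of splitting such a fused output, with illustrative sizes that are not taken from the diff:

    import torch
    import torch.nn as nn

    # Illustrative GQA sizes: 8 query heads, 2 query groups, head_size 16.
    n_head, n_query_groups, head_size, n_embd = 8, 2, 16, 128
    qkv = nn.Linear(n_embd, (n_head + 2 * n_query_groups) * head_size, bias=False)

    x = torch.randn(1, 4, n_embd)                  # (batch, seq, n_embd)
    fused = qkv(x)                                 # (1, 4, 192)
    # Grouped (non-interleaved) layout: one split along the feature dimension.
    q, k, v = fused.split(
        (n_head * head_size, n_query_groups * head_size, n_query_groups * head_size), dim=-1
    )
    print(q.shape, k.shape, v.shape)               # (1, 4, 128) (1, 4, 32) (1, 4, 32)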
16 changes: 12 additions & 4 deletions litgpt/adapter_v2.py
@@ -21,6 +21,7 @@
from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.adapter import Config as BaseConfig
from litgpt.model import KVCache
+ from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights


@@ -163,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
- self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
+ self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias)
# output projection
# if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias)
@@ -186,17 +187,24 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)
# For compatibility with older checkpoints
if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head:
state_dict[key] = state_dict[key].permute(0, 2, 1, 3)

for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.linear.{attr}"
current_key = f"{prefix}qkv.linear.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
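The loop above routes legacy attn.linear.* tensors through qkv_reassemble so that checkpoints saved with the old interleaved placement still load. As a rough illustration of the idea only (the real qkv_reassemble in litgpt.scripts.convert_hf_checkpoint may differ in details), de-interleaving a fused weight looks roughly like this:

    import torch

    def deinterleave_qkv(weight: torch.Tensor, n_head: int, n_query_groups: int, head_size: int) -> torch.Tensor:
        """Reorder a fused QKV weight from per-group [Q..Q, K, V] blocks to [all Q | all K | all V] (illustrative)."""
        q_per_kv = n_head // n_query_groups
        in_features = weight.size(-1)
        blocks = weight.view(n_query_groups, q_per_kv + 2, head_size, in_features)
        q, k, v = blocks.split((q_per_kv, 1, 1), dim=1)
        return torch.cat(
            (q.reshape(-1, in_features), k.reshape(-1, in_features), v.reshape(-1, in_features)), dim=0
        )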


13 changes: 6 additions & 7 deletions litgpt/generate/tp.py
@@ -3,31 +3,30 @@
import logging
import sys
import time
+ import warnings
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import Literal, Optional, Union
- import warnings

import lightning as L
- from lightning_utilities.core.imports import RequirementCache
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
from lightning.fabric.utilities import rank_zero_only
+ from lightning_utilities.core.imports import RequirementCache

import litgpt.generate.base as generate_base
- from litgpt.model import GPT
from litgpt.config import Config
- from litgpt.tokenizer import Tokenizer
- from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
+ from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
+ from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
check_nvlink_connectivity,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
- get_default_supported_precision
+ get_default_supported_precision,
)


@@ -71,7 +70,7 @@ def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMA


def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None:
- tensor_parallel_linear(fabric, attn.attn, "colwise")
+ tensor_parallel_linear(fabric, attn.qkv, "colwise")
tensor_parallel_linear(fabric, attn.proj, "rowwise")
attn.register_forward_hook(partial(all_reduce_output, fabric.world_size))
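For intuition, the column-wise/row-wise pairing shards the fused QKV output features and the matching projection input features, so each rank works on its own slice of heads and only the projection output needs an all-reduce. A generic sketch of that arithmetic (not the litgpt tensor_parallel_linear implementation, and with the attention non-linearity between the two matmuls omitted):

    import torch

    torch.manual_seed(0)
    n_embd, qkv_out, world_size = 64, 96, 2
    x = torch.randn(3, n_embd)
    w_qkv = torch.randn(qkv_out, n_embd)    # fused QKV weight, sharded column-wise (output features)
    w_proj = torch.randn(n_embd, qkv_out)   # output projection, sharded row-wise (input features)

    reference = (x @ w_qkv.T) @ w_proj.T    # single-device result

    partials = [
        (x @ w_q.T) @ w_p.T                 # each "rank" uses only its own shards
        for w_q, w_p in zip(w_qkv.chunk(world_size, dim=0), w_proj.chunk(world_size, dim=1))
    ]
    torch.testing.assert_close(reference, sum(partials))   # summing partials plays the role of the all-reduce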

63 changes: 25 additions & 38 deletions litgpt/lora.py
@@ -58,6 +58,7 @@
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.model import KVCache
+ from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights


@@ -267,18 +268,14 @@ def lora_ind(self) -> torch.Tensor:
# Indices are needed to properly pad weight updates with zeros.
if not hasattr(self, "_lora_ind"):
enable_q, enable_k, enable_v = self.enable_lora
- qkv_group_size = self.n_head // self.n_query_groups + 2
- candidate_indices = range(self.linear.out_features)
+ kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
lora_ind = []
if enable_q:
- q_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size < qkv_group_size - 2]
- lora_ind.extend(q_ind)
+ lora_ind.extend(range(0, self.linear.in_features))
if enable_k:
- k_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 2]
- lora_ind.extend(k_ind)
+ lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
if enable_v:
- v_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 1]
- lora_ind.extend(v_ind)
+ lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
self.register_buffer(
"_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
)
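With the grouped layout, the enabled-feature indices reduce to contiguous ranges. A worked example with illustrative sizes n_embd = 64, n_head = 8, n_query_groups = 2, head_size = 8 (so the fused projection has 96 output features: 64 query, 16 key, 16 value) and enable_lora = (True, False, True):

    in_features, out_features = 64, 96       # n_embd and (n_head + 2 * n_query_groups) * head_size
    n_head, n_query_groups = 8, 2
    kv_embd_size = in_features // (n_head // n_query_groups)    # 16 features for K (and for V)

    enable_q, enable_k, enable_v = True, False, True
    lora_ind = []
    if enable_q:
        lora_ind.extend(range(0, in_features))                              # 0..63: all queries
    if enable_k:
        lora_ind.extend(range(in_features, in_features + kv_embd_size))     # 64..79: all keys (skipped here)
    if enable_v:
        lora_ind.extend(range(in_features + kv_embd_size, out_features))    # 80..95: all values
    print(len(lora_ind), lora_ind[:3], lora_ind[-3:])                       # 80 [0, 1, 2] [93, 94, 95]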
@@ -298,27 +295,6 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
________________________________________
| query | key | value |
----------------------------------------
- For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped
- queries are adjacent to their associated key and value weights.
- For example, suppose we have n_head = 12 with 3 query groups.
- Then along the embedding dimension the interleaved weights would look like
- [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V],
- where each Q, K, and V has size head_size.
- In this case, the previously-described weight update applies separately to each
- individual block, so the update will take the form
- [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...],
- [.............................................................................],
- [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]]
- ↑ ↑ ↑ ↑ ↑ ↑
- ________________________________________________________________________________
- | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ...
- --------------------------------------------------------------------------------
- Note that in the above diagram, the size of each q block will equal q_per_kv
- times the size of each k and v block.
Args:
x: tensor with weights update that will be padded with zeros if necessary
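With those contiguous indices, zero-padding a LoRA update becomes a single scatter of the enabled columns into an otherwise-zero tensor. Continuing the toy sizes above, a minimal sketch (not the class method itself):

    import torch

    out_features = 96
    lora_ind = torch.tensor(list(range(0, 64)) + list(range(80, 96)))   # Q and V positions only

    x = torch.randn(2, 5, len(lora_ind))                 # update computed for the enabled features
    padded = x.new_zeros(*x.shape[:-1], out_features)    # zeros for the disabled K block
    padded[..., lora_ind] = x                            # place the update at its target columns
    print(padded.shape)                                  # torch.Size([2, 5, 96])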
@@ -391,7 +367,9 @@ def get_lora_AB(self) -> torch.Tensor:
lora = self.conv1d(
self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128)
self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
- ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
+ ).squeeze(
+ 0
+ ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128)

def merge(self) -> None:
@@ -430,7 +408,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
after_B = self.conv1d(
after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64)
self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1)
- ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
+ ).transpose(
+ -2, -1
+ ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384)
return pretrained + lora
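Both get_lora_AB and forward rely on a grouped 1x1 convolution to apply each enabled projection's B @ A independently in one call. A small check of that equivalence, under assumed shapes matching the comments above (rank 2, two enabled projections, 128 output features each):

    import torch
    import torch.nn.functional as F

    r, n_enabled, in_features, out_per_proj = 2, 2, 128, 128
    lora_A = torch.randn(r * n_enabled, in_features)     # (4, 128)
    lora_B = torch.randn(out_per_proj * n_enabled, r)    # (256, 2)

    # Grouped 1x1 conv: each group pairs one projection's A block with its B block.
    delta = F.conv1d(lora_A.unsqueeze(0), lora_B.unsqueeze(-1), groups=n_enabled).squeeze(0)   # (256, 128)

    manual = torch.cat(
        [B_i @ A_i for A_i, B_i in zip(lora_A.chunk(n_enabled, dim=0), lora_B.chunk(n_enabled, dim=0))],
        dim=0,
    )
    torch.testing.assert_close(delta, manual)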

@@ -602,7 +582,7 @@ def __init__(self, config: Config, block_idx: int) -> None:
nn.Module.__init__(self)
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
# key, query, value projections for all heads, but in a batch
- self.attn = LoRAQKVLinear(
+ self.qkv = LoRAQKVLinear(
in_features=config.n_embd,
out_features=shape,
r=config.lora_r,
@@ -628,21 +608,28 @@ def __init__(self, config: Config, block_idx: int) -> None:
# disabled by default
self.kv_cache: Optional[KVCache] = None
self.apply_sliding_window_attention = (
- config.sliding_window_size is not None and
- block_idx % config.sliding_window_layer_stride == 0
+ config.sliding_window_size is not None and
+ block_idx % config.sliding_window_layer_stride == 0
)

self.config = config

def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None:
"""For compatibility with base checkpoints."""
"""For compatibility with base and/or legacy checkpoints."""
mapping = {
"attn.weight": "attn.linear.weight",
"attn.bias": "attn.linear.bias",
"qkv.weight": "qkv.linear.weight",
"qkv.bias": "qkv.linear.bias",
"proj.weight": "proj.linear.weight",
"proj.bias": "proj.linear.bias",
}
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix)

for attr in ("weight", "bias"):
legacy_key = f"{prefix}attn.linear.{attr}"
current_key = f"{prefix}qkv.linear.{attr}"
if legacy_key in state_dict:
state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config)

super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
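The mapping above only renames keys (bare attn/qkv entries to the wrapped .linear parameters); reassembling legacy interleaved weights is handled separately by the loop that follows it. A generic sketch of such prefix-aware key renaming, for illustration only (not the actual map_old_state_dict_weights helper):

    from typing import Dict

    import torch

    def rename_keys(state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str], prefix: str) -> Dict[str, torch.Tensor]:
        """Rename checkpoint keys under a module prefix (illustrative helper)."""
        for old_suffix, new_suffix in mapping.items():
            old_key = prefix + old_suffix
            if old_key in state_dict:
                state_dict[prefix + new_suffix] = state_dict.pop(old_key)
        return state_dict

    # Hypothetical checkpoint entry for one attention block:
    sd = {"transformer.h.0.attn.qkv.weight": torch.zeros(96, 64)}
    sd = rename_keys(sd, {"qkv.weight": "qkv.linear.weight"}, prefix="transformer.h.0.attn.")
    print(list(sd))   # ['transformer.h.0.attn.qkv.linear.weight']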


@@ -758,4 +745,4 @@ def merge_lora_weights(model: GPT) -> None:
"""Merge LoRA weights into the full-rank weights to speed up inference."""
for module in model.modules():
if isinstance(module, LoRALinear):
- module.merge()
+ module.merge()
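Merging folds the zero-padded scaling * B @ A update into the pretrained weight once, so inference no longer pays for the extra LoRA matmuls. A rough usage sketch around the function above; everything except merge_lora_weights itself (model construction, checkpoint loading) is assumed rather than taken from the diff:

    import torch

    from litgpt.lora import GPT, Config, merge_lora_weights

    # Assumed setup: a LoRA config and a model with base plus fine-tuned LoRA weights already loaded.
    config = Config.from_name("pythia-160m", lora_r=8, lora_alpha=16, lora_query=True, lora_value=True)
    model = GPT(config)
    # model.load_state_dict(...)  # base checkpoint plus LoRA tensors

    merge_lora_weights(model)     # folds the low-rank updates into the full-rank weights
    model.eval()
    with torch.inference_mode():
        logits = model(torch.randint(0, config.padded_vocab_size, (1, 8)))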