Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama: Add support for RWKV v7 architecture #11452

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 194 additions & 32 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,40 @@ def _set_vocab_llama_hf(self):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def _set_vocab_rwkv_world(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)

tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]

with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)

self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)

def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
Expand Down Expand Up @@ -3327,38 +3361,7 @@ class Rwkv6Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV6

def set_vocab(self):
assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
vocab_size = self.hparams.get("vocab_size", 65536)

tokens: list[bytes] = ['<s>'.encode("utf-8")]
toktypes: list[int] = [gguf.TokenType.CONTROL]

with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
parts = line.split(' ')
assert len(parts) >= 3
token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
token = token.encode("utf-8") if isinstance(token, str) else token
assert isinstance(token, bytes)
assert len(token) == token_len
token_text: str = repr(token)[2:-1] # "b'\xff'" -> "\xff"
tokens.append(token_text.encode("utf-8"))
toktypes.append(gguf.TokenType.NORMAL)
remainder = vocab_size - len(tokens)
assert remainder >= 0
for i in range(len(tokens), vocab_size):
tokens.append(f"[PAD{i}]".encode("utf-8"))
toktypes.append(gguf.TokenType.UNUSED)

self.gguf_writer.add_tokenizer_model("rwkv")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.chat_template = "rwkv-world"
# hack: Add '\n\n' as the EOT token to make it chat normally
special_vocab._set_special_token("eot", 261)
special_vocab.add_to_gguf(self.gguf_writer)
self._set_vocab_rwkv_world()

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
Expand Down Expand Up @@ -3480,6 +3483,165 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield (new_name, data)


@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
class Rwkv7Model(Model):
model_arch = gguf.MODEL_ARCH.RWKV7

def set_vocab(self):
self._set_vocab_rwkv_world()

def calc_lora_rank(self, hidden_size, exponent, multiplier):
return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
try:
head_size = self.hparams["head_size"]
layer_norm_eps = self.hparams["layer_norm_epsilon"]
except KeyError:
head_size = self.hparams["head_dim"]
layer_norm_eps = self.hparams["norm_eps"]
hidden_size = self.hparams["hidden_size"]
intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)

# ICLR: In-Context-Learning-Rate
try:
lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
except KeyError:
lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)

# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)

lerp_weights: dict[int, dict[str, Tensor]] = {}
lora_needs_transpose: bool = True

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# unify tensor names here to make life easier
name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
name = name.replace("self_attn", "attention").replace("attn", "attention")
name = name.replace("time_mixer.", "")
# lora layer names in fla-hub's impl
if "_lora.lora" in name:
self.lora_needs_transpose = False
name = name.replace("_lora.lora.0.weight", "1.weight")
name = name.replace("_lora.lora.2.weight", "2.weight")
name = name.replace("_lora.lora.2.bias", "0.weight")

name = name.replace("feed_forward_norm", "ln2")
name = name.replace("g_norm", "ln_x")

if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
# some models have dummy v0/v1/v2 on first layer while others don't
# ignore them all since they are not used
return

if bid is not None and "attention.x_" in name:
if "attention.x_x" in name:
# already concatenated
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = data_torch.reshape(6, 1, -1)
yield (new_name, data)
else:
try:
self.lerp_weights[bid][name] = data_torch
except KeyError:
self.lerp_weights[bid] = {name: data_torch}
if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in ["r", "w", "k", "v", "a", "g"]):
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"].squeeze(0) for i in ["r", "w", "k", "v", "a", "g"]], dim=0)
yield (new_name, data)
return
else:
data_torch = data_torch.squeeze()
new_name = self.map_tensor_name(name)

if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
new_name += ".weight"

if self.lora_needs_transpose and any(
new_name.endswith(t) for t in [
"time_mix_w1.weight", "time_mix_w2.weight",
"time_mix_a1.weight", "time_mix_a2.weight",
"time_mix_v1.weight", "time_mix_v2.weight",
"time_mix_g1.weight", "time_mix_g2.weight",
]
):
data_torch = data_torch.transpose(0, 1)

if 'r_k' in new_name:
data_torch = data_torch.flatten()

if bid == 0 and "time_mix_a" in new_name:
# dummy v0/v1/v2 on first layer
# easist way to make llama happy
yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)

yield (new_name, data_torch)


@Model.register("RwkvHybridForCausalLM")
class ARwkv7Model(Rwkv7Model):
model_arch = gguf.MODEL_ARCH.ARWKV7

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
hidden_size = self.hparams["hidden_size"]
head_size = self.hparams["head_size"]
rms_norm_eps = self.hparams["rms_norm_eps"]
intermediate_size = self.hparams["intermediate_size"]
wkv_has_gate = self.hparams["wkv_has_gate"]
assert self.hparams["wkv_version"] == 7

# ICLR: In-Context-Learning-Rate
lora_rank_decay = 64
lora_rank_iclr = 64
lora_rank_value_residual_mix = 32
lora_rank_gate = 128 if wkv_has_gate else 0

# RWKV isn't context limited
self.gguf_writer.add_context_length(1048576)
self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_wkv_head_size(head_size)
self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_token_shift_count(1)

# required by llama.cpp, unused
self.gguf_writer.add_head_count(0)


@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA
Expand Down
24 changes: 24 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ extern "C" {
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_L2_NORM,

GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
Expand Down Expand Up @@ -501,6 +502,7 @@ extern "C" {
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
GGML_OP_RWKV_WKV7,
GGML_OP_GATED_LINEAR_ATTN,

GGML_OP_UNARY,
Expand Down Expand Up @@ -1095,6 +1097,18 @@ extern "C" {
int n_groups,
float eps);

// l2 normalize along rows
// used in rwkv v7
GGML_API struct ggml_tensor * ggml_l2_norm(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);

GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps);

// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_rms_norm_back(
Expand Down Expand Up @@ -1881,6 +1895,16 @@ extern "C" {
struct ggml_tensor * td,
struct ggml_tensor * state);

GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
struct ggml_context * ctx,
struct ggml_tensor * r,
struct ggml_tensor * w,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * state);

GGML_API struct ggml_tensor * ggml_gated_linear_attn(
struct ggml_context * ctx,
struct ggml_tensor * k,
Expand Down
Loading