add in entropix interface (aka shrek sampler)
iamlemec committed Oct 14, 2024
1 parent 17d6d0d commit cae45fb
Showing 3 changed files with 34 additions and 2 deletions.
README.md (4 changes: 3 additions & 1 deletion)
@@ -40,7 +40,9 @@ For more examples, see the `test_*` functions in `model.py` and `compute.py`, or
 - tokenization is hell, let Huggingface handle it!
 - can do rapid prototyping and experimentation without having to compile anything
 - no need for round-trips to and from the GPU (could be important for embeddings?)
-- easier to rapidly integrate things like novel sampling methods (entropix??)
+- easier to rapidly integrate things like novel sampling methods
+
+**BONUS**: We now have `entropix` integration! If you check out `gadget/shrek.py`, you'll find `ShrekGen` and `ShrekChat` analogs of `TextGen` and `TextChat` that use the sampling method from [entropix](https://github.com/xjdr-alt/entropix) (aka "Shrek sampler"). Note that you'll need to install `entropix` separately, as it is not a strict dependency of `gadget`.
 
 # Install
 
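The entropix ("Shrek") sampler itself lives in the external `entropix` package that `gadget/shrek.py` below wraps; broadly, it adapts the sampling strategy to how uncertain the model looks, using statistics of the logits and of the attention scores (which is why this commit also exposes the pre-softmax scores in `layers.py`). As a heavily simplified illustration of that idea, not the actual entropix code, and with made-up threshold and temperature values, an entropy-aware sampler might look like this:

```python
import torch
import torch.nn.functional as F

def entropy_aware_sample(logits, low_entropy=0.1, temperature=0.7):
    # greedy decoding when the model is confident (low entropy of the
    # next-token distribution), temperature sampling otherwise
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * probs.clamp_min(1e-10).log()).sum(dim=-1)
    if entropy.item() < low_entropy:
        return logits.argmax(dim=-1).item()
    hot = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(hot, num_samples=1).item()
```

The real sampler does considerably more (it also uses additional distribution statistics and attention-derived signals), but the shape of the interface is the same: tokens, logits, and attention scores in, next token out.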
gadget/models/layers.py (2 changes: 1 addition & 1 deletion)
@@ -112,7 +112,7 @@ def attention_layer(
 
     # compute interactions
     head_wgt = 1.0/sqrt(head_dim)
-    kq = ggml_mul_mat(ctx, k, q)
+    kq = ggml_mul_mat(ctx, k, q, name=f'{name}_pre_scores')
     kq = ggml_soft_max_ext(ctx, kq, mask, head_wgt, alibi, name=f'{name}_scores')
 
     # pull in values
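The only functional change in this file is attaching a name to the pre-softmax attention scores so that downstream code can pull them back out of the compute graph; note that the node is captured before the 1/sqrt(head_dim) scaling, which `ggml_soft_max_ext` applies via `head_wgt`. A minimal sketch of the lookup, mirroring the call in `gadget/shrek.py` below (the `get_named_node` and `params` accesses come from that file; `model` stands in for a loaded gadget model):

```python
# illustrative lookup of the newly named node, as done in gadget/shrek.py
n_layers = model.params['llama.block_count']
raw_scores = model.get_named_node(f'attn{n_layers-1}_pre_scores')
```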
gadget/shrek.py (30 changes: 30 additions & 0 deletions)
@@ -0,0 +1,30 @@
# shrek llama

import torch
from math import sqrt

from entropix.torch_sampler import sample, device
from .textgen import TextGen, TextChat

class ShrekMixin:
    def sample(self, tokens, **kwargs):
        # head dimension and name of the last layer's pre-softmax attention score node
        head_dim = self.model.params['head_dim_kv']
        n_layers = self.model.params['llama.block_count']
        score_name = f'attn{n_layers-1}_pre_scores'

        # run the model, then fetch the raw attention scores by name; the stored
        # node is unscaled, so apply the 1/sqrt(head_dim) factor here
        tokens = torch.tensor(tokens, dtype=torch.int32, device=device)
        logits = self.logits(tokens)
        scores = self.model.get_named_node(score_name) / sqrt(head_dim)

        # entropix's sample() expects batched tensors
        batch_tokens = tokens.unsqueeze(0)
        batch_logits = logits.unsqueeze(0)
        batch_scores = scores.unsqueeze(0)

        # run the entropix ("Shrek") sampler and return the next token id
        nexts = sample(batch_tokens, batch_logits, batch_scores, **kwargs)
        return nexts.squeeze(0).item()

class ShrekGen(TextGen, ShrekMixin):
    pass

class ShrekChat(TextChat, ShrekMixin):
    pass
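
A hypothetical usage sketch follows. Only `ShrekGen` and its entropix-backed `sample` method come from this commit; the constructor argument and the `generate` call are assumptions about the existing `TextGen` interface, not something defined here:

```python
# hypothetical usage: ShrekGen comes from this commit, but the constructor
# argument and generate() call are assumed from the TextGen interface
from gadget.shrek import ShrekGen

gen = ShrekGen('path/to/model.gguf')                    # assumed constructor signature
print(gen.generate('What is entropy-based sampling?'))  # assumed generation API
```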
