Skip to content

Commit

Permalink
Remove batched generation support (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Apr 19, 2023
1 parent 33ef184 commit 06302ea
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 42 deletions.
6 changes: 2 additions & 4 deletions finetune_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ def generate_response(model, instruction, input=""):
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
sample = {"instruction": instruction, "input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False)
encoded = encoded[None, :] # add batch dimension
encoded = encoded.to(model.device)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)

output = generate(
model,
Expand All @@ -167,7 +165,7 @@ def generate_response(model, instruction, input=""):
max_new_tokens=100,
temperature=0.8,
)
output = tokenizer.decode(output[0].cpu())
output = tokenizer.decode(output)
return output # output.split("### Response:")[1].strip()


Expand Down
6 changes: 2 additions & 4 deletions finetune_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,15 @@ def generate_response(model, instruction):
tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
sample = {"instruction": instruction, "input": ""}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False)
encoded = encoded[None, :] # add batch dimension
encoded = encoded.to(model.device)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)

output = generate(
model,
idx=encoded,
max_seq_length=block_size,
max_new_tokens=100,
)
output = tokenizer.decode(output[0].cpu())
output = tokenizer.decode(output)
return output # output.split("### Response:")[1].strip()


Expand Down
24 changes: 12 additions & 12 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,40 @@ def generate(
Args:
model: The model to use.
idx: Tensor of shape (B, T) with indices of the prompt sequence.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities
"""
# create an empty tensor of the expected final shape and fill in the current tokens
B, T = idx.shape
T = idx.size(0)
T_new = T + max_new_tokens
empty = torch.empty(B, T_new, dtype=idx.dtype, device=idx.device)
empty[:, :T] = idx
empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
empty[:T] = idx
idx = empty

# generate max_new_tokens tokens
for t in range(T, T_new):
# ignore the not-filled-yet tokens
idx_cond = idx[:, :t]
idx_cond = idx[:t]
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx_cond if T <= max_seq_length else idx_cond[:, -max_seq_length:]
idx_cond = idx_cond if T <= max_seq_length else idx_cond[-max_seq_length:]

# forward
logits = model(idx_cond)
logits = logits[:, -1] / temperature
logits = logits[-1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")
logits[logits < v[[-1]]] = -float("Inf")

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)

# concatenate the new column
idx[:, t:] = idx_next
# concatenate the new generation
idx[t] = idx_next

return idx

Expand Down Expand Up @@ -87,6 +87,7 @@ def main(
samples.
checkpoint_path: The checkpoint path to load.
tokenizer_path: The tokenizer path to load.
model_size: The model size to load.
quantize: Whether to quantize the model and using which method:
``"llm.int8"``: LLM.int8() mode,
``"gptq.int4"``: GPTQ 4-bit mode.
Expand Down Expand Up @@ -116,7 +117,6 @@ def main(

tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
encoded_prompt = encoded_prompt[None, :] # add batch dimension

L.seed_everything(1234)
t0 = time.perf_counter()
Expand All @@ -129,7 +129,7 @@ def main(
model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
)[0] # unpack batch dimension
)
print(tokenizer.decode(y))

t = time.perf_counter() - t0
Expand Down
9 changes: 2 additions & 7 deletions generate_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,9 @@ def main(
model = fabric.setup_module(model)

tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
encoded_prompt = encoded_prompt[None, :] # add batch dimension

sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False)
encoded = encoded[None, :] # add batch dimension
encoded = encoded.to(model.device)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)

t0 = time.perf_counter()
output = generate(
Expand All @@ -100,7 +95,7 @@ def main(
top_k=top_k,
)
# The end of the response is where the model generates the EOS token
output = truncate_output_to_eos(output[0].cpu(), tokenizer.eos_id)
output = truncate_output_to_eos(output, tokenizer.eos_id)
output = tokenizer.decode(output)
output = output.split("### Response:")[1].strip()

Expand Down
9 changes: 2 additions & 7 deletions generate_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,9 @@ def main(
model = fabric.setup_module(model)

tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
encoded_prompt = encoded_prompt[None, :] # add batch dimension

sample = {"instruction": prompt, "input": input}
prompt = generate_prompt(sample)
encoded = tokenizer.encode(prompt, bos=True, eos=False)
encoded = encoded[None, :] # add batch dimension
encoded = encoded.to(model.device)
encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)

t0 = time.perf_counter()
output = generate(
Expand All @@ -114,7 +109,7 @@ def main(
top_k=top_k,
)
# The end of the response is where the model generates the EOS token
output = truncate_output_to_eos(output[0].cpu(), tokenizer.eos_id)
output = truncate_output_to_eos(output, tokenizer.eos_id)
output = tokenizer.decode(output)
output = output.split("### Response:")[1].strip()

Expand Down
14 changes: 6 additions & 8 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from unittest import mock
from unittest.mock import Mock, call, ANY

import pytest
import torch

wd = Path(__file__).parent.parent.absolute()
Expand All @@ -22,13 +21,12 @@ def load_generate_script():
return generate


@pytest.mark.parametrize("B", (1, 2))
def test_generate(B):
def test_generate():
generate = load_generate_script()

T, C = 5, 3
logits = torch.randn(B, T, C)
input_idx = torch.randint(10, size=(B, T))
logits = torch.randn(T, C)
input_idx = torch.randint(10, size=(T,))

model = Mock(return_value=logits)
max_new_tokens = 20
Expand All @@ -42,11 +40,11 @@ def multinomial(*args, **kwargs):
return out

with mock.patch("torch.multinomial", multinomial):
out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10)
out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10, top_k=4)

assert out.shape == (B, T + max_new_tokens)
assert out.size(0) == T + max_new_tokens
multinomial_results = torch.hstack(multinomial_results)
expected = torch.cat((input_idx, multinomial_results), dim=1)
expected = torch.cat((input_idx, multinomial_results))
assert out.shape == expected.shape
torch.testing.assert_close(out, expected)

Expand Down

0 comments on commit 06302ea

Please sign in to comment.