Add vocab_size to Tokenizer.train (Lightning-AI#69)
lantiga authored Mar 31, 2023
1 parent ba505cb commit f7aa8b9
Showing 3 changed files with 4 additions and 3 deletions.
lit_llama/tokenizer.py (2 additions, 2 deletions)
@@ -44,6 +44,6 @@ def decode(self, tokens: torch.Tensor) -> str:
         return self.processor.decode(tokens.tolist())
 
     @staticmethod
-    def train(input: str, destination: str) -> None:
+    def train(input: str, destination: str, vocab_size=32000) -> None:
         model_prefix = os.path.join(destination, "tokenizer")
-        SentencePieceTrainer.Train(input=input, model_prefix=model_prefix)
+        SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
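For reference, Tokenizer.train is a thin wrapper around SentencePieceTrainer.Train from the sentencepiece package. A minimal standalone sketch of what the updated method does (the corpus path and output directory below are hypothetical):

    import os

    from sentencepiece import SentencePieceTrainer

    # Hypothetical paths for illustration; any plain-text corpus works.
    input_file = "data/corpus.txt"
    destination = "checkpoints/tokenizer"

    os.makedirs(destination, exist_ok=True)
    model_prefix = os.path.join(destination, "tokenizer")

    # Writes tokenizer.model and tokenizer.vocab under `destination`.
    # SentencePiece raises an error if vocab_size exceeds the number of
    # distinct pieces the corpus can support, so small corpora need a
    # correspondingly small vocabulary.
    SentencePieceTrainer.Train(input=input_file, model_prefix=model_prefix, vocab_size=32000)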
scripts/prepare_shakespeare.py (1 addition, 1 deletion)
@@ -45,7 +45,7 @@ def prepare(destination_path: Path = Path("data/shakespeare")) -> None:
 
     from lit_llama import Tokenizer
 
-    Tokenizer.train(input=input_file_path, destination=destination_path)
+    Tokenizer.train(input=input_file_path, destination=destination_path, vocab_size=100)
     tokenizer = Tokenizer(destination_path / "tokenizer.model")
     train_ids = tokenizer.encode(train_data)
     val_ids = tokenizer.encode(val_data)
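vocab_size=100 is deliberately tiny here because the Shakespeare corpus is small, and SentencePiece refuses to train a vocabulary larger than the corpus can support. A quick round-trip sanity check of the freshly trained tokenizer might look like this (a sketch; the sample string is arbitrary):

    from pathlib import Path

    from lit_llama import Tokenizer

    tokenizer = Tokenizer(Path("data/shakespeare") / "tokenizer.model")
    tokens = tokenizer.encode("To be, or not to be")  # torch.Tensor of token ids
    print(tokenizer.decode(tokens))  # should print the original text back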
train.py (1 addition, 0 deletions)
@@ -53,6 +53,7 @@ def main() -> None:
 
     config = LLaMAConfig.from_name("7B")
     config.block_size = block_size
+    config.vocab_size = 100  # from prepare_shakespeare.py
 
     with fabric.device:
         model = LLaMA(config)
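The hardcoded 100 has to match the vocabulary size the tokenizer was actually trained with, or the model's embedding table and the token ids will disagree. One way to keep the two in sync, continuing from the hunk above, is to read the size off the trained model instead of hardcoding it. This is a sketch, assuming the Tokenizer exposes its underlying SentencePieceProcessor as processor, as decode() in the first file suggests:

    from pathlib import Path

    from lit_llama import Tokenizer

    # Derive vocab_size from the trained tokenizer rather than hardcoding it;
    # SentencePieceProcessor.vocab_size() reports the trained vocabulary size.
    tokenizer = Tokenizer(Path("data/shakespeare") / "tokenizer.model")
    config.vocab_size = tokenizer.processor.vocab_size()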
