Disable non-functioning torch.compile (#17)
carmocca authored Mar 27, 2023
1 parent 51df486 commit 6557f21
Showing 3 changed files with 11 additions and 9 deletions.
11 changes: 6 additions & 5 deletions generate.py
@@ -5,7 +5,7 @@
 import lightning as L
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=None):
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
@@ -79,7 +79,8 @@ def main(
     max_new_tokens: int = 20,
     top_k: int = 200,
     temperature: float = 0.8,
-    compile: bool = False,
+    # compilation fails as it does not support torch.complex64 for RoPE
+    # compile: bool = False,
     accelerator: str = "auto",
     precision: str = "32-true",
     checkpoint_path: str = "/srv/data/checkpoints/llama/converted_nano/7B/state_dict.pth",
@@ -96,7 +97,7 @@ def main(
         top_k: The number of top most probable tokens to consider in the sampling process.
         temperature: A value controlling the randomness of the sampling process. Higher values result in more random
             samples.
-        compile: Whether to compile the model.
+        # compile: Whether to compile the model.
         accelerator: The hardware to run on. Possible choices are:
             ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
         precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
@@ -117,8 +118,8 @@ def main(
     model.load_state_dict(checkpoint, strict=(not original_model))
 
     model.eval()
-    if compile:
-        model = torch.compile(model)
+    # if compile:
+    #     model = torch.compile(model)
     model = fabric.setup_module(model, move_to_device=False)
 
     tokenizer = Tokenizer(tokenizer_path)
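Note on the decorator swap above: @torch.inference_mode() marks every tensor it produces as an inference tensor, and those raise an error if they are later touched by autograd-tracking code; @torch.no_grad() only disables gradient recording. A minimal sketch of the difference (illustrative only, not code from this commit):

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # gradient recording is off, but y is an ordinary tensor
y.add_(1)      # fine: y can still be mutated and reused outside the block

with torch.inference_mode():
    z = x * 2  # z is an inference tensor
# z.add_(1) here would raise: inference tensors cannot be updated outside inference mode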
2 changes: 1 addition & 1 deletion scripts/download.py
@@ -7,7 +7,7 @@ def download_original(wd: str):
     if not os.path.isfile(filepath):
         print(f"Downloading original implementation to {filepath!r}")
         urllib.request.urlretrieve(
-            url="https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/c4509e48a53ebb6a195a6f073b5267a69e47b45a/llama_model.py",
+            url="https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/fe9561d1abd8d2c61c82dd62155fe98d3ac74c43/llama_model.py",
             filename="original_model.py",
         )
         print("Done")
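Note on the URL change above: a gist's raw URL pins an exact snapshot because the revision SHA is embedded in the path, so bumping the hash fetches a specific version of the file. A sketch of the layout (the format string below is an illustration, not code from the repo):

# Raw gist URLs follow gist.githubusercontent.com/<user>/<gist_id>/raw/<revision>/<file>
GIST_RAW = "https://gist.githubusercontent.com/{user}/{gist_id}/raw/{revision}/{filename}"

url = GIST_RAW.format(
    user="lantiga",
    gist_id="fd36849fb1c498da949a0af635318a7b",
    revision="fe9561d1abd8d2c61c82dd62155fe98d3ac74c43",  # the updated snapshot
    filename="llama_model.py",
)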
7 changes: 4 additions & 3 deletions train.py
@@ -16,7 +16,8 @@
 eval_interval = 2000
 eval_iters = 200
 log_interval = 1
-compile = False
+# compilation fails as it does not support torch.complex64 for RoPE
+# compile = False
 
 # Hyperparameters
 learning_rate = 6e-4
@@ -60,8 +61,8 @@ def main():
     with fabric.device:
         model = LLaMA(config)
 
-    if compile:
-        model = torch.compile(model)
+    # if compile:
+    #     model = torch.compile(model)
 
     model = fabric.setup_module(model)
 
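Note on why compile is disabled: at the time of this commit, torch.compile could not handle the torch.complex64 arithmetic used by rotary position embeddings (RoPE). A minimal sketch of the offending pattern (illustrative, not the repo's exact code):

import torch

def apply_rope(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Pair up the last dimension and reinterpret it as complex64 numbers.
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # The rotation is a complex multiply with the precomputed frequencies.
    return torch.view_as_real(xc * freqs_cis).flatten(-2).type_as(x)

# torch.compile(apply_rope) failed on these complex64 ops, so the flag is
# commented out rather than left in place as a silently broken option.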
