From 6557f214e012e8efb24ba5dbb377c51101720961 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 27 Mar 2023 16:02:03 +0200
Subject: [PATCH] Disable non-functioning torch.compile (#17)

---
 generate.py         | 11 ++++++-----
 scripts/download.py |  2 +-
 train.py            |  7 ++++---
 3 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/generate.py b/generate.py
index f625c6e..84a18b3 100644
--- a/generate.py
+++ b/generate.py
@@ -5,7 +5,7 @@
 import lightning as L
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=None):
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
@@ -79,7 +79,8 @@ def main(
     max_new_tokens: int = 20,
     top_k: int = 200,
     temperature: float = 0.8,
-    compile: bool = False,
+    # compilation fails as it does not support torch.complex64 for RoPE
+    # compile: bool = False,
     accelerator: str = "auto",
     precision: str = "32-true",
     checkpoint_path: str = "/srv/data/checkpoints/llama/converted_nano/7B/state_dict.pth",
@@ -96,7 +97,7 @@ def main(
         top_k: The number of top most probable tokens to consider in the sampling process.
         temperature: A value controlling the randomness of the sampling process. Higher values result in more random
             samples.
-        compile: Whether to compile the model.
+        # compile: Whether to compile the model.
         accelerator: The hardware to run on. Possible choices are:
             ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
         precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
@@ -117,8 +118,8 @@ def main(
     model.load_state_dict(checkpoint, strict=(not original_model))
     model.eval()
 
-    if compile:
-        model = torch.compile(model)
+    # if compile:
+    #     model = torch.compile(model)
 
     model = fabric.setup_module(model, move_to_device=False)
     tokenizer = Tokenizer(tokenizer_path)
diff --git a/scripts/download.py b/scripts/download.py
index f3f6d41..40c7022 100644
--- a/scripts/download.py
+++ b/scripts/download.py
@@ -7,7 +7,7 @@ def download_original(wd: str):
     if not os.path.isfile(filepath):
         print(f"Downloading original implementation to {filepath!r}")
         urllib.request.urlretrieve(
-            url="https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/c4509e48a53ebb6a195a6f073b5267a69e47b45a/llama_model.py",
+            url="https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/fe9561d1abd8d2c61c82dd62155fe98d3ac74c43/llama_model.py",
             filename="original_model.py",
         )
         print("Done")
diff --git a/train.py b/train.py
index d42c2c9..f360cc9 100644
--- a/train.py
+++ b/train.py
@@ -16,7 +16,8 @@
 eval_interval = 2000
 eval_iters = 200
 log_interval = 1
-compile = False
+# compilation fails as it does not support torch.complex64 for RoPE
+# compile = False
 
 # Hyperparameters
 learning_rate = 6e-4
@@ -60,8 +61,8 @@ def main():
 
     with fabric.device:
         model = LLaMA(config)
 
-    if compile:
-        model = torch.compile(model)
+    # if compile:
+    #     model = torch.compile(model)
 
     model = fabric.setup_module(model)