From 7a9f3e0f0d7f004564cfb4c4f7a329d3a71c7b28 Mon Sep 17 00:00:00 2001 From: Yunxuan Xiao Date: Mon, 24 Jul 2023 21:12:04 -0700 Subject: [PATCH] [Doc][Example] Fine-tune `vicuna-13b-v1.3` with LightningTrainer + DeepSpeed (#37016) Signed-off-by: woshiyyya Signed-off-by: Yunxuan Xiao Signed-off-by: matthewdeng Co-authored-by: matthewdeng Signed-off-by: e428265 --- doc/source/_toc.yml | 2 + doc/source/ray-overview/examples.rst | 7 + doc/source/train/examples.rst | 8 + doc/source/train/examples/lightning/BUILD | 5 +- ...una_13b_lightning_deepspeed_finetune.ipynb | 1424 +++++++++++++++++ .../test_myst_doc.py | 1 + .../vicuna_13b_deepspeed_compute_aws.yaml | 20 + .../vicuna_13b_deepspeed_env.yaml | 27 + ...una_13b_lightning_deepspeed_finetune.ipynb | 1 + release/release_tests.yaml | 21 + 10 files changed, 1515 insertions(+), 1 deletion(-) create mode 100644 doc/source/train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune.ipynb create mode 120000 release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/test_myst_doc.py create mode 100644 release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_compute_aws.yaml create mode 100644 release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_env.yaml create mode 120000 release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_lightning_deepspeed_finetune.ipynb diff --git a/doc/source/_toc.yml b/doc/source/_toc.yml index e6b0a66357524..35be6f2945091 100644 --- a/doc/source/_toc.yml +++ b/doc/source/_toc.yml @@ -139,6 +139,8 @@ parts: title: "Torch Data Prefetching Benchmark" - file: train/examples/pytorch/pytorch_resnet_finetune title: "PyTorch Finetuning ResNet Example" + - file: train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune + title: "Fine-tune Vicuna-13B with DeepSpeed and PyTorch Lightning" - file: train/faq - file: train/api/api diff --git a/doc/source/ray-overview/examples.rst b/doc/source/ray-overview/examples.rst index 5aa2698943f67..62acc120ce7ee 100644 --- a/doc/source/ray-overview/examples.rst +++ b/doc/source/ray-overview/examples.rst @@ -1618,3 +1618,10 @@ Ray Examples .. button-ref:: /serve/tutorials/streaming Using Ray Serve to deploy a chatbot + + .. grid-item-card:: :bdg-secondary:`Code example` + :class-item: gallery-item training llm + + .. button-ref:: /train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune + + Fine-tune vicuna-13b-v1.3 with DeepSpeed and LightningTrainer diff --git a/doc/source/train/examples.rst b/doc/source/train/examples.rst index b1c7807d330ab..66143800e3405 100644 --- a/doc/source/train/examples.rst +++ b/doc/source/train/examples.rst @@ -72,6 +72,14 @@ Distributed Training Examples using Ray Train .. button-ref:: dolly_lightning_fsdp_finetuning Fine-tune LLM with AIR LightningTrainer and FSDP + + .. grid-item-card:: + :img-top: /images/pytorch_lightning_small.png + :class-img-top: pt-2 w-75 d-block mx-auto fixed-height-img + + .. button-ref:: vicuna_lightning_deepspeed_finetuning + + Fine-tune vicuna-13b-v1.3 with Deepspeed and LightningTrainer Ray Train Examples Using Loggers & Callbacks diff --git a/doc/source/train/examples/lightning/BUILD b/doc/source/train/examples/lightning/BUILD index 97d8822771b25..7532a168e1b79 100644 --- a/doc/source/train/examples/lightning/BUILD +++ b/doc/source/train/examples/lightning/BUILD @@ -10,7 +10,10 @@ filegroup( py_test_run_all_notebooks( size="large", include=["*.ipynb"], - exclude=["lightning_exp_tracking.ipynb"], + exclude=[ + "lightning_exp_tracking.ipynb", # CPU test + "vicuna_13b_lightning_deepspeed_finetune.ipynb", # Release Test + ], data=["//doc/source/train/examples/lightning:lightning_examples"], tags=["exclusive", "team:ml", "gpu", "ray_air"], ) diff --git a/doc/source/train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune.ipynb b/doc/source/train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune.ipynb new file mode 100644 index 0000000000000..debf22976e255 --- /dev/null +++ b/doc/source/train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune.ipynb @@ -0,0 +1,1424 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(vicuna_lightning_deepspeed_finetuning)=\n", + "\n", + "# Fine-tune `vicuna-13b` with Ray LightningTrainer and DeepSpeed\n", + "\n", + "In this example, we will demonstrate how to perform full fine-tuning for a [`vicuna-13b-v1.3`](https://huggingface.co/lmsys/vicuna-13b-v1.3) model using LightningTrainer with the DeepSpeed ZeRO-3 strategy.\n", + "\n", + "- [DeepSpeed]() is an open-source deep learning optimization library for PyTorch. It's designed to reduce computing power and memory usage, and to train large distributed models by leveraging state-of-the-art innovations like ZeRO, 3D-Parallelism, DeepSpeed-MoE, and ZeRO-Infinity. \n", + "- PyTorch Lightning offers a [DeepSpeed integration](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.strategies.DeepSpeedStrategy.html), which provides a simple interface to configure the knobs for DeepSpeed and automatically trigger your training process with the DeepSpeed Engine.\n", + "- {class}`Ray LightningTrainer ` allows you to easily scale your PyTorch Lightning job across multiple nodes in a Ray cluster, without worrying about the underlying cluster management, autoscaling, and distributed process group settings.\n", + "\n", + "Our demo aims to illustrate how these three tools can be combined effectively to finetune the Vicuna-13B model, leveraging the strengths of each to create an efficient and high-performance deep learning solution.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{note}\n", + "This is an advanced example of Large Language Model fine-tuning with Ray AIR. If you're a beginner or new to the concepts of Ray AIR and LightningTrainer, it would be beneficial to first explore the introductory documentation below to build a foundational understanding. \n", + "- [Ray AIR Key Concepts](air-key-concepts) \n", + "- [Ray Data Key Concepts](data_key_concepts)\n", + "- {ref}`[Basic] Image Classification with LightningTrainer `\n", + "- {ref}`[Intermediate] Using LightningTrainer with Ray Data `\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cluster Setting\n", + "\n", + "\n", + "### Compute instances\n", + "In this example, we set up a Ray cluster on AWS with the following settings:\n", + "\n", + "| | num | instance type | GPU per node | GPU Memory | CPU Memory |\n", + "|-|-|-|-|-|-|\n", + "|Head node|1|g5.16xlarge|1 x A10G | 24 GB | 256 GB|\n", + "|Worker node|15|g5.4xlarge|1 x A10G | 24 GB | 64 GB|\n", + "\n", + "```{note}\n", + "In this example, we used 16 A10G GPUs for model training and tuned the DeepSpeed configurations for this setup. If you have a different cluster setup or GPUs with lower memory capacities, you may need to modify the DeepSpeed configurations and batch size to fit the model into the GPUs.\n", + "```\n", + "\n", + "```{tip}\n", + "We selected a GPU instance with additional CPU memory for the head node to demonstrate single-node offline inference. If you are training only, you can still opt for the g5.4xlarge instance for the head node.\n", + "```\n", + "\n", + "\n", + "### Cloud Storage\n", + "\n", + "Additionally, since the checkpoint size for this 13B parameter model can be large (~140GB), we choose to store the checkpoints in AWS S3. Thanks to the newly introduced distributed checkpointing feature in Ray 2.5, each worker can upload its own shards individually to the S3 bucket, greatly reducing the latency and network traffic of checkpoint syncing.\n", + "\n", + "### Local Storage\n", + "To demonstrate offline inference, we need to download and consolidate the model checkpoint onto the head node. This action requires around 200GB disk storage. Therefore, we mounted the NVMe SSD provided by g5 instances at `/dev/nvme1n1` to `/mnt/local_storage`, and we will save the checkpoints in this folder.\n", + "\n", + "For more details, please refer to[Amazon EBS and NVMe on Linux instances](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/nvme-ebs-volumes.html).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Ray Environment\n", + "\n", + "We define a runtime environment to ensure that the Ray workers have access to all necessary packages. If you have already included these dependencies in your Docker image or installed them on each node, you can ignore the `runtime_env` argument.\n", + "\n", + "```{note}\n", + "Note that the codebases of `transformers`, `accelerate`, and `deepspeed` are all rapidly changing, so we have pinned the package versions here to ensure testing stability. You can try other version combinations and feel free to report any issues you encounter.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "\n", + "NUM_WORKERS = 16\n", + "BATCH_SIZE_PER_WORKER = 8\n", + "MODEL_NAME = \"lmsys/vicuna-13b-v1.3\"\n", + "\n", + "ray.init(\n", + " runtime_env={\n", + " \"pip\": [\n", + " \"datasets==2.13.1\",\n", + " \"torch>=1.13.0\",\n", + " \"deepspeed==0.9.4\",\n", + " \"accelerate==0.20.3\",\n", + " \"transformers==4.30.2\",\n", + " \"pytorch_lightning==2.0.3\",\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load and preprocess datasets\n", + "\n", + "We were impressed by LLM's ability of zero-shot text-generation, while some LLMs may not perform well in code generation due to the lack of code in the training corpus. The CMU [CoNaLa](https://conala-corpus.github.io/)(The Code/Natural Language Challenge) was designed to test systems for generating program snippets from natural language. Each data record contains an intent sentence and a one-line code snippet. The goal is to fine-tune the Vicuna model on this dataset, enabling the model to generate correct and runnable code snippets, thereby achieving natural language intent. Here are some examples:\n", + "\n", + "| intent | code snippet |\n", + "| - | - |\n", + "| \"convert a list of integers into a single integer\" | `r = int(''.join(map(str, x)))`|\n", + "| \"normalize a pandas dataframe `df` by row\" | `df.div(df.sum(axis=1), axis=0)` | \n", + "| \"Convert string '03:55' into datetime.time object\" | `datetime.datetime.strptime('03:55', '%H:%M').time()` |\n", + "\n", + "The CoNaLa team has released a dataset crawled from Stack Overflow, automatically filtered, then curated by annotators, split into 2379 training and 500 test examples. In addition, they also included an automatically-mined dataset with 600k examples. In this demo, we take all the curated data and the top 5000 mined data for fine-tuning." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we preprocess the CoNaLa dataset with Ray Data. You can also use HuggingFace Datasets and pass it directly to `LightningConfigBuilder.fit_params()`." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "535afe3e183b4cdfa61c39cbae788608", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/2 [00:00\",\n", + " axis=1,\n", + " )\n", + " return batch[[\"input_sentence\"]]\n", + "\n", + "\n", + "# Tokenize input sentences to tensors\n", + "def tokenize(batch):\n", + " tokenizer = AutoTokenizer.from_pretrained(\n", + " MODEL_NAME, padding_side=\"left\", use_fast=False\n", + " )\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + " ret = tokenizer(\n", + " list(batch[\"input_sentence\"]),\n", + " truncation=True,\n", + " max_length=128,\n", + " padding=\"max_length\",\n", + " return_tensors=\"np\",\n", + " )\n", + " ret[\"labels\"] = ret[\"input_ids\"].copy()\n", + " return dict(ret)\n", + "\n", + "# Preprocess train dataset\n", + "processed_ds = ray_ds.map_batches(fill_prompt, batch_format=\"pandas\").map_batches(tokenize, batch_format=\"pandas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define your model\n", + "\n", + "Here we load the pre-trained model weights from HuggingFace Model Hub, and wrap them into `pl.LightningModule`. We adopted the efficient model initialization techniques introduced in [Lightning-transformers](https://github.com/Lightning-Universe/lightning-transformers) to avoid unnecessary full weights loading." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-06-30 17:39:35,109] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import torch\n", + "import transformers\n", + "import pytorch_lightning as pl\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from deepspeed.ops.adam import DeepSpeedCPUAdam\n", + "\n", + "\n", + "class ZeRO3Config:\n", + " def __init__(self, pl_module):\n", + " self.config = pl_module.trainer.strategy.config\n", + "\n", + " def __call__(self, *args, **kwargs):\n", + " return self\n", + "\n", + " def is_zero3(self) -> bool:\n", + " return True\n", + "\n", + "\n", + "def enable_transformers_pretrained_deepspeed_sharding(\n", + " pl_module: \"pl.LightningModule\",\n", + ") -> None:\n", + " transformers.deepspeed._hf_deepspeed_config_weak_ref = ZeRO3Config(pl_module)\n", + "\n", + "\n", + "class Vicuna13BModel(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " # Enable tf32 for better performance\n", + " torch.backends.cuda.matmul.allow_tf32 = True\n", + "\n", + " def setup(self, stage) -> None:\n", + " # Defer model initialization to inject deepspeed configs to HF.\n", + " # During initialization, HF transformers can immediately partition \n", + " # the model across all gpus avoid the overhead in time and memory \n", + " # copying it on CPU or each GPU first.\n", + " enable_transformers_pretrained_deepspeed_sharding(self)\n", + " self.model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n", + " if self.global_rank == 0:\n", + " print(\"DeepSpeed Configs: \", self.trainer.strategy.config)\n", + " print(\"Model Archetecture: \", self.model)\n", + "\n", + " def forward(self, batch):\n", + " outputs = self.model(\n", + " batch[\"input_ids\"],\n", + " labels=batch[\"labels\"],\n", + " attention_mask=batch[\"attention_mask\"],\n", + " )\n", + " return outputs.loss\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " loss = self.forward(batch)\n", + " self.log(\"train_loss\", loss, prog_bar=True, on_step=True, sync_dist=True)\n", + " return loss\n", + "\n", + " def configure_optimizers(self):\n", + " return DeepSpeedCPUAdam(self.parameters(), lr=2e-5, weight_decay=0.01)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training Configurations\n", + "\n", + "Before training, let's calculate the memory usage of finetuning a `vicuna-13b` model. Assume we are using FP16 mixed-precision training, and the optimizer is Adam with FP32 states.\n", + "\n", + "- Model parameters: 13(billion parameters) * 2(FP16) ≈ 26GB\n", + "- Optimizer states: 13(billion parameters) * 2(momentums per param) * 4 (FP32) ≈ 52GB\n", + "\n", + "As we can see, the model parameters themselves require 26GB, which cannot fit in a single A10G GPU, let alone the activations and optimizers states. Here, we use ZeRO stage-3 to partition the model, gradients, and optimizer states across 16 nodes. Additionally, we employ optimizer CPU offloading to reduce GRAM usage and increase throughput with larger batch sizes. We also disabled parameter offloading and activation checkpointing to improve the training speed.\n", + "\n", + "Regarding other knobs such as `reduce_bucket_size`, `stage3_prefetch_bucket_size` and `stage3_param_persistence_threshold`, we kept them as the [default values in HuggingFace](https://huggingface.co/docs/transformers/main_classes/deepspeed#zero3-config). Feel free to further adjust them to speed up the training process." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.train.lightning import LightningTrainer, LightningConfigBuilder\n", + "from transformers import AutoConfig\n", + "\n", + "config = AutoConfig.from_pretrained(MODEL_NAME)\n", + "HIDDEN_SIZE = config.hidden_size\n", + "\n", + "deepspeed_configs = {\n", + " \"zero_allow_untested_optimizer\": True,\n", + " \"bf16\": {\"enabled\": True},\n", + " \"zero_optimization\": {\n", + " \"stage\": 3,\n", + " \"offload_optimizer\": {\"device\": \"cpu\", \"pin_memory\": True},\n", + " \"overlap_comm\": True,\n", + " \"contiguous_gradients\": True,\n", + " \"reduce_bucket_size\": HIDDEN_SIZE * HIDDEN_SIZE,\n", + " \"stage3_prefetch_bucket_size\": 0.9 * HIDDEN_SIZE * HIDDEN_SIZE,\n", + " \"stage3_param_persistence_threshold\": 10 * HIDDEN_SIZE,\n", + " },\n", + "}\n", + "\n", + "lightning_config = (\n", + " LightningConfigBuilder()\n", + " .module(cls=Vicuna13BModel)\n", + " .trainer(\n", + " max_epochs=1,\n", + " accelerator=\"gpu\",\n", + " precision=\"bf16-mixed\",\n", + " accumulate_grad_batches=2,\n", + " )\n", + " .strategy(name=\"deepspeed\", config=deepspeed_configs)\n", + " .checkpointing(save_top_k=0, save_weights_only=True, save_last=True)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "from pytorch_lightning.callbacks import TQDMProgressBar\n", + "\n", + "# Create a customized progress bar for LightningTrainer\n", + "class VicunaProgressBar(TQDMProgressBar):\n", + " def __init__(self, num_iters_per_epoch, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.num_iters_per_epoch = num_iters_per_epoch\n", + "\n", + " def on_train_epoch_start(self, trainer, *_):\n", + " super().on_train_epoch_start(trainer, *_)\n", + " self.train_progress_bar.reset(self.num_iters_per_epoch)\n", + "\n", + "\n", + "total_batches = processed_ds.count()\n", + "num_iters_per_epoch = total_batches // (NUM_WORKERS * BATCH_SIZE_PER_WORKER)\n", + "progress_bar = VicunaProgressBar(num_iters_per_epoch)\n", + "\n", + "\n", + "lightning_config.trainer(\n", + " callbacks=[progress_bar],\n", + " # Take a subset to accelerate release tests\n", + " limit_train_batches=20,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, combine all the configurations with {class}`LightningConfigBuilder ` and instantiate a LightningTrainer. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ray.air.config import CheckpointConfig, RunConfig, ScalingConfig\n", + "\n", + "trainer = LightningTrainer(\n", + " lightning_config=lightning_config.build(),\n", + " run_config=RunConfig(\n", + " name=\"vicuna-13b-finetune\",\n", + " storage_path=\"s3://anyscale-staging-data-cld-kvedzwag2qa8i5bjxuevf5i7/air-release-tests\",\n", + " checkpoint_config=CheckpointConfig(\n", + " num_to_keep=1,\n", + " # Enable distributed checkpointing\n", + " _checkpoint_keep_all_ranks=True,\n", + " _checkpoint_upload_from_workers=True,\n", + " ),\n", + " ),\n", + " scaling_config=ScalingConfig(\n", + " num_workers=NUM_WORKERS,\n", + " use_gpu=True,\n", + " resources_per_worker={\"CPU\": 15, \"GPU\": 1},\n", + " ),\n", + " datasets={\"train\": processed_ds},\n", + " datasets_iter_config={\"batch_size\": BATCH_SIZE_PER_WORKER},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{tip}\n", + "\n", + "Here, we highly recommend saving checkpoints with cloud storage and enabling distributed checkpointing by setting `_checkpoint_keep_all_ranks` and `_checkpoint_upload_from_workers` to True when training huge models. Otherwise, all checkpoint shards will be synced to the head node, which may introduce enormous syncing overhead and even cause out-of-memory.\n", + "\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Fine-tuning\n", + "\n", + "Once everything is configured in LightningTrainer, training becomes easy. Simply call `trainer.fit()`, and your workload will be scaled to the Ray cluster, initiating ZeRO-3 parallel training." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Tune Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2023-06-30 18:21:59
Running for: 00:42:22.75
Memory: 10.7/249.1 GiB
\n", + "
\n", + "
\n", + "
\n", + "

System Info

\n", + " Using FIFO scheduling algorithm.
Logical resource usage: 241.0/304 CPUs, 16.0/16 GPUs (0.0/16.0 accelerator_type:A10G)\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Trial name status loc iter total time (s) train_loss epoch step
LightningTrainer_c1544_00000TERMINATED10.0.55.20:134103 1 2473.94 0.523438 0 29
\n", + "
\n", + "
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[2m\u001b[36m(pid=134103)\u001b[0m [2023-06-30 17:39:41,637] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m The `preprocessor` arg to Trainer is deprecated. Apply preprocessor transformations ahead of time by calling `preprocessor.transform(ds)`. Support for the preprocessor arg will be dropped in a future release.\n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m \u001b[33mImportant: Ray Data requires schemas for all datasets in Ray 2.5. This means that standalone Python objects are no longer supported. In addition, the default batch format is fixed to NumPy. To revert to legacy behavior temporarily, set the environment variable RAY_DATA_STRICT_MODE=0 on all cluster processes.\n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m \n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m Learn more here: https://docs.ray.io/en/master/data/faq.html#migrating-to-strict-mode\u001b[0m\n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m Starting distributed worker processes: ['134267 (10.0.55.20)', '74152 (10.0.63.141)', '75476 (10.0.51.205)', '75547 (10.0.42.158)', '74711 (10.0.45.211)', '75132 (10.0.20.140)', '74502 (10.0.60.86)', '75695 (10.0.53.69)', '74457 (10.0.47.2)', '74569 (10.0.33.23)', '74341 (10.0.29.61)', '74274 (10.0.36.152)', '74561 (10.0.35.16)', '74427 (10.0.16.236)', '74273 (10.0.54.55)', '74996 (10.0.9.249)']\n", + "\u001b[2m\u001b[36m(RayTrainWorker pid=134267)\u001b[0m Setting up process group for: env:// [rank=0, world_size=16]\n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(BatchMapper._transform_pandas)->MapBatches(BatchMapper._transform_pandas)] -> AllToAllOperator[RandomizeBlockOrder]\n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", + "\u001b[2m\u001b[36m(LightningTrainer pid=134103)\u001b[0m Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "da7f200767b448d7b409fcdd07daecce", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "(pid=134103) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00.*<' in xml string `line`\",\n", + " },\n", + " {\n", + " \"intent\": \"send a signal `signal.SIGUSR1` to the current process\",\n", + " },\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's begin by examining the generated outputs without fine-tuning. In this case study, we utilize [Aviary Explorer](https://aviary.anyscale.com), an open-source multi-LLM serving platform supported by Ray and Anyscale. You can easily select from a variety of open-source LLMs and compare their generation quality, cost, latency, and many other metrics.\n", + "\n", + "We constructed a prompt in a zero-shot learning manner and feed it into 3 OSS LLMs.\n", + "\n", + "![](https://user-images.githubusercontent.com/26745457/250704232-65a20f1b-6752-4d6c-bba1-8296a373162f.png)\n", + "\n", + "\n", + "- `vicuna-13b-v1.3` begins to speak Chinese.\n", + "- `mpt-7b-chat` generates a reasonable code snippet, but with multiple lines.\n", + "- `falcon-7b-sft` generates a one line snippet, but it doesn't seem to work.\n", + "\n", + "As we can see, none of them generate a satisfactory code snippet. \n", + "\n", + "Now let's check the performance of our fine-tuned `vicuna-13b-v1.3` model:" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ray/anaconda3/lib/python3.10/site-packages/transformers/pipelines/base.py:1081: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Intent: replace white spaces in colunm 'col' of dataframe `df` with '_'\n", + "One-line code snippet: `df['col'] = df['col'].str.replace(' ', '_')`\n", + "\n", + "Intent: search for occurrences of regex pattern '>.*<' in xml string `line`\n", + "One-line code snippet: `re.findall('>.*<', line)``\n", + "\n", + "Intent: send a signal `signal.SIGUSR1` to the current process\n", + "One-line code snippet: `os.kill(os.getpid(), signal.SIGUSR1)``\n" + ] + } + ], + "source": [ + "for case in testcases:\n", + " prompt = PROMPT_TEMPLATE.format(intent=case[\"intent\"], snippet=\"\")\n", + " output = generator(prompt, max_new_tokens=30, do_sample=True)\n", + " print(output[0][\"generated_text\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test the Generated Code Snippets\n", + "\n", + "The generated code snippets look pretty reasonable. The results covered Pandas operations, regular expressions, and Linux commands. Let's test them one by one." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before\n", + " col\n", + "0 abc def ghi\n", + "1 12 3 456\n", + "2 \n", + "After\n", + " col\n", + "0 abc_def_ghi\n", + "1 _12_3_456\n", + "2 _____\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame.from_dict({\"col\": [\"abc def ghi\", \" 12 3 456\", \" \"]})\n", + "print(\"Before\\n\", df)\n", + "\n", + "df[\"col\"] = df[\"col\"].str.replace(\" \", \"_\")\n", + "print(\"After\\n\", df)" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['>The Great Gatsby<',\n", + " '>F. Scott Fitzgerald<',\n", + " '>1925<',\n", + " '>Sapiens: A Brief History of Humankind<',\n", + " '>Yuval Noah Harari<',\n", + " '>2011<']" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import re\n", + "\n", + "line = \"\"\"\n", + "\n", + " \n", + " The Great Gatsby\n", + " F. Scott Fitzgerald\n", + " 1925\n", + " \n", + " \n", + " Sapiens: A Brief History of Humankind\n", + " Yuval Noah Harari\n", + " 2011\n", + " \n", + "\n", + "\"\"\"\n", + "re.findall(\">.*<\", line)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's hand it over to LLM and let it wrap up the demo:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os, signal\n", + "\n", + "os.kill(os.getpid(), signal.SIGUSR1) # Terminate the current process~" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References:\n", + "\n", + "- [CoNaLa: The Code/Natural Language Challenge](https://conala-corpus.github.io/)\n", + "- [HuggingFace: DeepSpeed Integration](https://huggingface.co/docs/transformers/main_classes/deepspeed#deepspeed-integration)\n", + "- [HuggingFace: Handling big models for inference](https://huggingface.co/docs/accelerate/main/usage_guides/big_modeling)\n", + "- [Lightning Transformers: DeepSpeed Training with Big Transformer Models](https://lightning-transformers.readthedocs.io/en/latest/)\n", + "- [Aviary: Open Source Multi-LLM Serving](https://www.anyscale.com/blog/announcing-aviary-open-source-multi-llm-serving-solution)\n", + "- Rajbhandari, S., Rasley, J., et al. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. [arXiv:1910.02054](https://arxiv.org/abs/1910.02054)\n", + "- Zheng, L., Chiang, W-L., Sheng, Y., et al. (2023). Judging LLM-as-a-judge with MT-Bench and Chatbot Arena. [arXiv:2306.05685](https://arxiv.org/abs/2306.05685)\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/test_myst_doc.py b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/test_myst_doc.py new file mode 120000 index 0000000000000..c265ccc7b062b --- /dev/null +++ b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/test_myst_doc.py @@ -0,0 +1 @@ +../../../doc/test_myst_doc.py \ No newline at end of file diff --git a/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_compute_aws.yaml b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_compute_aws.yaml new file mode 100644 index 0000000000000..17f69c81a906a --- /dev/null +++ b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_compute_aws.yaml @@ -0,0 +1,20 @@ +cloud_id: {{env["ANYSCALE_CLOUD_ID"]}} +region: us-west-2 + +head_node_type: + name: head_node + instance_type: g5.16xlarge + +worker_node_types: + - name: worker_node + instance_type: g5.4xlarge + min_workers: 15 + max_workers: 15 + use_spot: false + +aws: + TagSpecifications: + - ResourceType: "instance" + Tags: + - Key: ttl-hours + Value: '24' diff --git a/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_env.yaml b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_env.yaml new file mode 100644 index 0000000000000..77acb25855284 --- /dev/null +++ b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_deepspeed_env.yaml @@ -0,0 +1,27 @@ +base_image: {{ env["RAY_IMAGE_ML_NIGHTLY_GPU"] | default("anyscale/ray:nightly-py38-cu118") }} +env_vars: {} +debian_packages: + - curl + +python: + pip_packages: + - datasets==2.13.1 + - evaluate==0.4.0 + - scikit-learn==1.3.0 + - boto3==1.28.5 + - myst-parser==0.15.2 + - myst-nb==0.13.1 + - jupytext==1.13.6 + - typing-extensions<4.6.0 + conda_packages: [] + +post_build_cmds: + - pip uninstall -y ray || true && pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }} + - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }} + - echo "sudo lsblk -f" >> ~/.bashrc + - echo "yes N | sudo mkfs -t ext4 /dev/nvme1n1 || true" >> ~/.bashrc + - echo "mkdir -p /mnt/local_storage" >> ~/.bashrc + - echo "sudo chmod 0777 /mnt/local_storage" >> ~/.bashrc + - echo "sudo mount /dev/nvme1n1 /mnt/local_storage || true" >> ~/.bashrc + - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 + - pip3 install "pytorch_lightning==2.0.3" "transformers==4.30.2" "accelerate==0.20.3" "deepspeed==0.9.4" diff --git a/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_lightning_deepspeed_finetune.ipynb b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_lightning_deepspeed_finetune.ipynb new file mode 120000 index 0000000000000..ccd34dcfc22fa --- /dev/null +++ b/release/air_examples/vicuna_13b_lightning_deepspeed_finetuning/vicuna_13b_lightning_deepspeed_finetune.ipynb @@ -0,0 +1 @@ +../../../doc/source/train/examples/lightning/vicuna_13b_lightning_deepspeed_finetune.ipynb \ No newline at end of file diff --git a/release/release_tests.yaml b/release/release_tests.yaml index 7699ea7bcd31a..0fd334c7e33bb 100644 --- a/release/release_tests.yaml +++ b/release/release_tests.yaml @@ -967,6 +967,27 @@ # variations: TODO(jungong): add GCP variation. +- name: air_example_vicuna_13b_lightning_deepspeed_finetuning + group: AIR examples + working_dir: air_examples/vicuna_13b_lightning_deepspeed_finetuning + + python: "3.8" + + frequency: weekly + team: ml + cluster: + byod: + type: cu118 + pip: + - myst-parser==0.15.2 + - myst-nb==0.13.1 + - jupytext==1.13.6 + cluster_env: vicuna_13b_deepspeed_env.yaml + cluster_compute: vicuna_13b_deepspeed_compute_aws.yaml + + run: + timeout: 4700 + script: python test_myst_doc.py --path vicuna_13b_lightning_deepspeed_finetune.ipynb ##################################### # Workspace templates release tests #