From e1d023e41606c9b76b35e1d231c2f13368a30eca Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Fri, 1 Sep 2023 22:39:32 +0300 Subject: [PATCH 1/8] Add test to show global_max_sequence_length can never exceed an LLMs context length (#3548) --- tests/integration_tests/test_llm.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index f32a7716911..a10d7745b39 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -504,3 +504,27 @@ def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool: if not torch.equal(key_item_1[1], key_item_2[1]): return False return True + + +def test_global_max_sequence_length_for_llms(): + """Ensures that user specified global_max_sequence_length can never be greater than the model's context + length.""" + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM", + INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})], + OUTPUT_FEATURES: [text_feature(name="output")], + } + config_obj = ModelConfig.from_dict(config) + model = LLM(config_obj) + + # Default value is set based on model's context_len + assert model.global_max_sequence_length == 2048 + + # Override to a larger value in the config + config["preprocessing"] = {"global_max_sequence_length": 4096} + config_obj = ModelConfig.from_dict(config) + model = LLM(config_obj) + + # Check that the value can never be larger than the model's context_len + assert model.global_max_sequence_length == 2048 From 07ea4725e33fb3cf26a99cd6ea394839ad8755c2 Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Wed, 6 Sep 2023 01:42:19 +0300 Subject: [PATCH 2/8] WandB: Add metric logging support on eval end and epoch end (#3586) --- ludwig/contribs/wandb.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ludwig/contribs/wandb.py b/ludwig/contribs/wandb.py index 2710d958460..17f1011d327 100644 --- a/ludwig/contribs/wandb.py +++ b/ludwig/contribs/wandb.py @@ -53,6 +53,16 @@ def on_train_start(self, model, config, *args, **kwargs): del config["output_features"] wandb.config.update(config) + def on_eval_end(self, trainer, progress_tracker, save_path): + """Called from ludwig/models/model.py.""" + for key, value in progress_tracker.log_metrics().items(): + wandb.log({key: value}) + + def on_epoch_end(self, trainer, progress_tracker, save_path): + """Called from ludwig/models/model.py.""" + for key, value in progress_tracker.log_metrics().items(): + wandb.log({key: value}) + def on_visualize_figure(self, fig): logger.info("wandb.on_visualize_figure() called...") if wandb.run: From 3d2ff0b1537473b5b5aea1a672cf673639f89b10 Mon Sep 17 00:00:00 2001 From: Kabir Brar Date: Tue, 5 Sep 2023 20:28:11 -0400 Subject: [PATCH 3/8] schema: Add `prompt` validation check (#3564) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ludwig/config_validation/checks.py | 55 +++++++++++++++++++ tests/ludwig/config_validation/test_checks.py | 27 +++++++++ 2 files changed, 82 insertions(+) diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index 9387048519a..1621ee4809f 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -1,6 +1,7 @@ """Checks that are not easily covered by marshmallow JSON schema validation like parameter 
interdependencies.""" from abc import ABC, abstractmethod +from re import findall from typing import Callable, TYPE_CHECKING from transformers import AutoConfig @@ -606,3 +607,57 @@ def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821 if config.quantization and (not config.adapter or config.adapter.type != "lora"): raise ConfigValidationError("Fine-tuning and LLM with quantization requires using the 'lora' adapter") + + +@register_config_check +def check_prompt_requirements(config: "ModelConfig") -> None: # noqa: F821 + """Checks that prompt's template and task properties are valid, according to the description on the schema.""" + if config.model_type != MODEL_LLM: + return + + # TODO: `prompt` by default should be set to null, not a default dict: + # # If no prompt is provided, no validation necessary: + # if not config.prompt: + # return + from ludwig.schema.llms.prompt import PromptConfig, RetrievalConfig + + if config.prompt == PromptConfig(): + return + + template = config.prompt.template + task = config.prompt.task + retrieval = config.prompt.retrieval + + # If template is NOT provided, then task is required for zero/few shot learning: + if not template and not task: + raise ConfigValidationError("A prompt task is required if no template is provided!") + + template_refs = set(findall(r"\{(.*?)\}", template)) if isinstance(template, str) else set() + + # If a template IS provided (i.e. we are not doing a built-in zero/few-shot learning), then... + if template: + # If task is also provided, the template must contain it: + if task and "__task__" not in template_refs: + raise ConfigValidationError( + "When providing a task, you must make sure that the task keyword `{__task__} is " + "present somewhere in the template string!" + ) + + # If retrieval is also provided, the template must reference it: + # TODO: retrieval by default should be set to null, not a default dict: + if retrieval and retrieval != RetrievalConfig() and "__context__" not in template_refs: + raise ConfigValidationError( + "When providing a retrieval config, you must make sure that the task keyword `{__context__}` is " + "present somewhere in the template string!" + ) + + # Otherwise, the template should at least contain the sample keyword or some input column: + # TODO: len(template_refs) is a hacky attempt to check that there are references to *something* in the + # string. The proper validation is to check the references against the features in the user's dataset - but we + # do not have access to the dataset in this code path right now. + if not task: + if len(template_refs) == 0 and "__sample__" not in template_refs: + raise ConfigValidationError( + "A template must contain at least one reference to a column or the sample keyword {__sample__} for " + "a JSON-serialized representation of non-output feature columns." 
+ ) diff --git a/tests/ludwig/config_validation/test_checks.py b/tests/ludwig/config_validation/test_checks.py index fb25c42152f..b46d14a17c2 100644 --- a/tests/ludwig/config_validation/test_checks.py +++ b/tests/ludwig/config_validation/test_checks.py @@ -459,3 +459,30 @@ def test_check_qlora(): "type": "lora", } ModelConfig.from_dict(config) + + +def test_check_prompt_requirements(): + config = { + "model_type": "llm", + "input_features": [ + text_feature(name="test1", column="col1", encoder={"type": "passthrough"}), + ], + "output_features": [text_feature()], + "base_model": "opt-350m", + } + + ModelConfig.from_dict(config) + + config["prompt"] = {"task": "Some task"} + ModelConfig.from_dict(config) + + config["prompt"] = {"task": "Some task", "template": "Some template not mentioning the task"} + with pytest.raises(ConfigValidationError): + ModelConfig.from_dict(config) + + config["prompt"] = {"task": "Some task", "template": "{__invalid__}"} + with pytest.raises(ConfigValidationError): + ModelConfig.from_dict(config) + + config["prompt"] = {"task": "Some task", "template": "{__task__}"} + ModelConfig.from_dict(config) From d2d682e091303bff1bce24878734569e4c7f5bac Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:37:43 +0300 Subject: [PATCH 4/8] Unpin Transformers for CodeLlama support (#3592) --- requirements.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 42f8d994dab..ae1528cef32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,9 +11,7 @@ torchaudio torchtext torchvision pydantic<2.0 -# TODO(Arnav): Lift transformers package restriction once 4.32.2 is released. -# Issue: https://github.com/ludwig-ai/ludwig/issues/3571 -transformers>=4.31.0,<4.32.0 +transformers tokenizers>=0.13.3 spacy>=2.3 PyYAML>=3.12,<6.0.1,!=5.4.* #Exlude PyYAML 5.4.* due to incompatibility with awscli From 6f9ed8fa6d5d6f70d7729eb17c71d51746b348bf Mon Sep 17 00:00:00 2001 From: Arnav Garg <106701836+arnavgarg1@users.noreply.github.com> Date: Fri, 8 Sep 2023 20:36:19 +0300 Subject: [PATCH 5/8] Add support for Paged Optimizers (Adam, Adamw), 8-bit optimizers, and new optimizers: LARS, LAMB and LION (#3588) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ludwig/distributed/deepspeed.py | 4 + ludwig/modules/optimization_modules.py | 9 + ludwig/schema/optimizers.py | 445 +++++++++++++++++- ludwig/trainers/trainer.py | 3 + ludwig/utils/llm_utils.py | 60 ++- ludwig/utils/model_utils.py | 17 + requirements.txt | 3 + requirements_llm.txt | 1 - .../test_model_training_options.py | 10 + tests/ludwig/utils/test_model_utils.py | 38 +- .../training_success/test_training_success.py | 16 + 11 files changed, 599 insertions(+), 7 deletions(-) diff --git a/ludwig/distributed/deepspeed.py b/ludwig/distributed/deepspeed.py index 02f24312fc0..12986f2afdb 100644 --- a/ludwig/distributed/deepspeed.py +++ b/ludwig/distributed/deepspeed.py @@ -82,6 +82,10 @@ def prepare( batch_size = ( trainer_config.batch_size if isinstance(trainer_config.batch_size, int) else MIN_POSSIBLE_BATCH_SIZE ) + # Paged and 8-bit optimizers are not supported by Deepspeed - just whatever is supported + # by torch.optim.Optimizer. https://www.deepspeed.ai/docs/config-json/#optimizer-parameters. 
+ if trainer_config.optimizer.is_paged or trainer_config.optimizer.is_8bit: + raise ValueError("Cannot use a paged or 8-bit optimizer with DeepSpeed.") optimizer_cls, optimizer_kwargs = get_optimizer_class_and_kwargs(trainer_config.optimizer, base_learning_rate) ds_config = { "amp": { diff --git a/ludwig/modules/optimization_modules.py b/ludwig/modules/optimization_modules.py index 1d2cb0b7c40..2a2b8760541 100644 --- a/ludwig/modules/optimization_modules.py +++ b/ludwig/modules/optimization_modules.py @@ -65,5 +65,14 @@ def create_optimizer( :param optimizer_config: Instance of `ludwig.modules.optimization_modules.BaseOptimizerConfig`. :return: Initialized instance of a torch optimizer. """ + # Make sure the optimizer is compatible with the available resources: + if (optimizer_config.is_paged or optimizer_config.is_8bit) and ( + not torch.cuda.is_available() or torch.cuda.device_count() == 0 + ): + raise ValueError( + "Cannot use a paged or 8-bit optimizer on a non-GPU machine. " + "Please use a different optimizer or run on a machine with a GPU." + ) + optimizer_cls, optimizer_kwargs = get_optimizer_class_and_kwargs(optimizer_config, learning_rate) return optimizer_cls(model.parameters(), **optimizer_kwargs) diff --git a/ludwig/schema/optimizers.py b/ludwig/schema/optimizers.py index 3c1fa6ed749..b7d6d0a8268 100644 --- a/ludwig/schema/optimizers.py +++ b/ludwig/schema/optimizers.py @@ -2,6 +2,7 @@ from dataclasses import field from typing import ClassVar, Dict, Optional, Tuple, Type +import bitsandbytes as bnb import torch from marshmallow import fields, ValidationError @@ -50,6 +51,16 @@ class BaseOptimizerConfig(schema_utils.BaseMarshmallowConfig, ABC): a `ValidationError`. """ + @property + def is_paged(self) -> bool: + """Returns True if the optimizer is a Paged optimizer.""" + return False + + @property + def is_8bit(self) -> bool: + """Returns True if the optimizer is an 8-bit optimizer.""" + return False + @DeveloperAPI @register_optimizer(name="sgd") @@ -66,18 +77,56 @@ class SGDOptimizerConfig(BaseOptimizerConfig): # Defaults taken from https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD : momentum: float = schema_utils.NonNegativeFloat( - default=0.0, description="Momentum factor.", parameter_metadata=OPTIMIZER_METADATA["momentum"] + default=0.0, + description="Momentum factor.", + parameter_metadata=OPTIMIZER_METADATA["momentum"], ) + weight_decay: float = schema_utils.NonNegativeFloat( - default=0.0, description="Weight decay ($L2$ penalty).", parameter_metadata=OPTIMIZER_METADATA["weight_decay"] + default=0.0, + description="Weight decay ($L2$ penalty).", + parameter_metadata=OPTIMIZER_METADATA["weight_decay"], ) + dampening: float = schema_utils.NonNegativeFloat( - default=0.0, description="Dampening for momentum.", parameter_metadata=OPTIMIZER_METADATA["dampening"] + default=0.0, + description="Dampening for momentum.", + parameter_metadata=OPTIMIZER_METADATA["dampening"], ) + nesterov: bool = schema_utils.Boolean( - default=False, description="Enables Nesterov momentum.", parameter_metadata=OPTIMIZER_METADATA["nesterov"] + default=False, + description="Enables Nesterov momentum.", + parameter_metadata=OPTIMIZER_METADATA["nesterov"], + ) + + +@DeveloperAPI +@register_optimizer(name="sgd_8bit") +@ludwig_dataclass +class SGD8BitOptimizerConfig(SGDOptimizerConfig): + """Parameters for stochastic gradient descent.""" + + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.SGD8bit + + type: str = schema_utils.ProtectedString("sgd_8bit") + + 
block_wise: bool = schema_utils.Boolean( + default=False, + description="Whether to use block wise update.", + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", ) + @property + def is_8bit(self) -> bool: + return True + @DeveloperAPI @register_optimizer(name="lbfgs") @@ -169,6 +218,61 @@ class AdamOptimizerConfig(BaseOptimizerConfig): ) +@DeveloperAPI +@register_optimizer(name="adam_8bit") +@ludwig_dataclass +class Adam8BitOptimizerConfig(AdamOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.Adam8bit + + type: str = schema_utils.ProtectedString("adam_8bit") + + block_wise: bool = schema_utils.Boolean( + default=True, + description="Whether to use block wise update.", + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + @property + def is_8bit(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="paged_adam") +@ludwig_dataclass +class PagedAdamOptimizerConfig(Adam8BitOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.PagedAdam + + type: str = schema_utils.ProtectedString("paged_adam") + + @property + def is_paged(self) -> bool: + return True + + @property + def is_8bit(self) -> bool: + return False + + +@DeveloperAPI +@register_optimizer(name="paged_adam_8bit") +@ludwig_dataclass +class PagedAdam8BitOptimizerConfig(PagedAdamOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.PagedAdam8bit + + type: str = schema_utils.ProtectedString("paged_adam_8bit") + + @property + def is_8bit(self) -> bool: + return True + + @DeveloperAPI @register_optimizer(name="adamw") @ludwig_dataclass @@ -207,6 +311,61 @@ class AdamWOptimizerConfig(BaseOptimizerConfig): ) +@DeveloperAPI +@register_optimizer(name="adamw_8bit") +@ludwig_dataclass +class AdamW8BitOptimizerConfig(AdamWOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.AdamW8bit + + type: str = schema_utils.ProtectedString("adamw_8bit") + + block_wise: bool = schema_utils.Boolean( + default=True, + description="Whether to use block wise update.", + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + @property + def is_8bit(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="paged_adamw") +@ludwig_dataclass +class PagedAdamWOptimizerConfig(AdamW8BitOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.PagedAdamW + + type: str = schema_utils.ProtectedString("paged_adamw") + + @property + def is_paged(self) -> bool: + return True + + @property + def is_8bit(self) -> bool: + return False + + +@DeveloperAPI +@register_optimizer(name="paged_adamw_8bit") +@ludwig_dataclass +class PagedAdamW8BitOptimizerConfig(PagedAdamWOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.PagedAdamW8bit + + type: str = schema_utils.ProtectedString("paged_adamw_8bit") + + @property + def is_8bit(self) -> bool: + return True + + @DeveloperAPI @register_optimizer(name="adadelta") @ludwig_dataclass @@ -274,6 +433,31 @@ class AdagradOptimizerConfig(BaseOptimizerConfig): ) +@DeveloperAPI +@register_optimizer(name="adagrad_8bit") +@ludwig_dataclass +class Adagrad8BitOptimizerConfig(AdagradOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.Adagrad8bit + + type: str = 
schema_utils.ProtectedString("adagrad_8bit") + + block_wise: bool = schema_utils.Boolean( + default=True, + description="Whether to use block wise update.", + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + @property + def is_8bit(self) -> bool: + return True + + @DeveloperAPI @register_optimizer(name="adamax") @ludwig_dataclass @@ -404,6 +588,259 @@ class RMSPropOptimizerConfig(BaseOptimizerConfig): weight_decay: float = schema_utils.NonNegativeFloat(default=0.0, description="Weight decay ($L2$ penalty).") +@DeveloperAPI +@register_optimizer(name="rmsprop_8bit") +@ludwig_dataclass +class RMSProp8BitOptimizerConfig(RMSPropOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.RMSprop8bit + + type: str = schema_utils.ProtectedString("rmsprop_8bit") + + block_wise: bool = schema_utils.Boolean( + default=True, + description="Whether to use block wise update.", + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + @property + def is_8bit(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="lamb") +@ludwig_dataclass +class LAMBOptimizerConfig(BaseOptimizerConfig): + """Layer-wise Adaptive Moments optimizer for Batch training. + + Paper: https://arxiv.org/pdf/1904.00962.pdf + """ + + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.LAMB + + type: str = schema_utils.ProtectedString("lamb") + + bias_correction: bool = schema_utils.Boolean( + default=True, + ) + + betas: Tuple[float, float] = schema_utils.FloatRangeTupleDataclassField( + default=(0.9, 0.999), + description="Coefficients used for computing running averages of gradient and its square.", + parameter_metadata=OPTIMIZER_METADATA["betas"], + ) + + eps: float = schema_utils.NonNegativeFloat( + default=1e-08, + description="Term added to the denominator to improve numerical stability.", + parameter_metadata=OPTIMIZER_METADATA["eps"], + ) + + weight_decay: float = schema_utils.NonNegativeFloat( + default=0.0, + description="Weight decay (L2 penalty).", + parameter_metadata=OPTIMIZER_METADATA["weight_decay"], + ) + + amsgrad: bool = schema_utils.Boolean( + default=False, + description="Whether to use the AMSGrad variant of this algorithm from the paper 'On the Convergence of Adam " + "and Beyond'.", + parameter_metadata=OPTIMIZER_METADATA["amsgrad"], + ) + + adam_w_mode: bool = schema_utils.Boolean( + default=True, + description="Whether to use the AdamW mode of this algorithm from the paper " + "'Decoupled Weight Decay Regularization'.", + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + block_wise: bool = schema_utils.Boolean( + default=False, + description="Whether to use block wise update.", + ) + + max_unorm: float = schema_utils.FloatRange( + default=1.0, + min=0.0, + max=1.0, + ) + + +@DeveloperAPI +@register_optimizer(name="lamb_8bit") +@ludwig_dataclass +class LAMB8BitOptimizerConfig(LAMBOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.LAMB8bit + + type: str = schema_utils.ProtectedString("lamb_8bit") + + @property + def is_8bit(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="lars") +@ludwig_dataclass +class LARSOptimizerConfig(BaseOptimizerConfig): + """Layerwise Adaptive Rate Scaling. 
+ + Paper: https://arxiv.org/pdf/1708.03888.pdf + """ + + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.LARS + + type: str = schema_utils.ProtectedString("lars") + + # 0.9 taken from the original paper - momentum requires a non zero value + # https://arxiv.org/pdf/1708.03888v3.pdf + momentum: float = schema_utils.FloatRange( + default=0.9, + min=0.0, + max=1.0, + min_inclusive=False, + description="Momentum factor.", + parameter_metadata=OPTIMIZER_METADATA["momentum"], + ) + + dampening: float = schema_utils.FloatRange( + default=0.0, + min=0.0, + max=1.0, + description="Dampening for momentum.", + parameter_metadata=OPTIMIZER_METADATA["dampening"], + ) + + weight_decay: float = schema_utils.NonNegativeFloat( + default=0.0, + description="Weight decay (L2 penalty).", + parameter_metadata=OPTIMIZER_METADATA["weight_decay"], + ) + + nesterov: bool = schema_utils.Boolean( + default=False, + description="Enables Nesterov momentum.", + parameter_metadata=OPTIMIZER_METADATA["nesterov"], + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + max_unorm: float = schema_utils.FloatRange( + default=1.0, + min=0.0, + max=1.0, + ) + + +@DeveloperAPI +@register_optimizer(name="lars_8bit") +@ludwig_dataclass +class LARS8BitOptimizerConfig(LARSOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.LARS8bit + + type: str = schema_utils.ProtectedString("lars_8bit") + + @property + def is_8bit(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="lion") +@ludwig_dataclass +class LIONOptimizerConfig(BaseOptimizerConfig): + """Evolved Sign Momentum. + + Paper: https://arxiv.org/pdf/2302.06675.pdf + """ + + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.Lion + + type: str = schema_utils.ProtectedString("lion") + + betas: Tuple[float, float] = schema_utils.FloatRangeTupleDataclassField( + default=(0.9, 0.999), + description="Coefficients used for computing running averages of gradient and its square.", + parameter_metadata=OPTIMIZER_METADATA["betas"], + ) + + weight_decay: float = schema_utils.NonNegativeFloat( + default=0.0, + description="Weight decay (L2 penalty).", + parameter_metadata=OPTIMIZER_METADATA["weight_decay"], + ) + + percentile_clipping: int = schema_utils.IntegerRange( + default=100, + min=0, + max=100, + description="Percentile clipping.", + ) + + block_wise: bool = schema_utils.Boolean( + default=True, + description="Whether to use block wise update.", + ) + + +@DeveloperAPI +@register_optimizer(name="lion_8bit") +@ludwig_dataclass +class LION8BitOptimizerConfig(LIONOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.Lion8bit + + type: str = schema_utils.ProtectedString("lion_8bit") + + @property + def is_8bit(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="paged_lion") +@ludwig_dataclass +class PagedLionOptimizerConfig(LIONOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.PagedLion + + type: str = schema_utils.ProtectedString("paged_lion") + + @property + def is_paged(self) -> bool: + return True + + +@DeveloperAPI +@register_optimizer(name="paged_lion_8bit") +@ludwig_dataclass +class PagedLion8BitOptimizerConfig(PagedLionOptimizerConfig): + optimizer_class: ClassVar[torch.optim.Optimizer] = bnb.optim.PagedLion8bit + + type: str = schema_utils.ProtectedString("paged_lion_8bit") + + @property + def is_8bit(self) -> bool: + return True + + 
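[Editor's note: illustrative example, not part of this patch] The optimizer classes registered above plug into the existing trainer schema, so they are selected through the optimizer "type" field like any other Ludwig optimizer. The sketch below shows how one of the new bitsandbytes-backed optimizers might be requested in a config; the base model and feature names are placeholders, and per the checks added earlier in this series the paged/8-bit variants also require a CUDA GPU and a non-DeepSpeed backend.

# Minimal sketch (assumes a CUDA-capable machine and the local backend).
config = {
    "model_type": "llm",
    "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",  # placeholder base model
    "input_features": [{"name": "input", "type": "text"}],
    "output_features": [{"name": "output", "type": "text"}],
    "adapter": {"type": "lora"},
    "trainer": {
        "type": "finetune",
        "optimizer": {
            "type": "paged_adamw_8bit",  # any type registered above, e.g. "lion", "lars", "adam_8bit"
            "percentile_clipping": 100,  # fields exposed by the 8-bit configs above
            "block_wise": True,
        },
    },
}

Keeping the 8-bit and paged variants as subclasses of their full-precision counterparts defines the shared hyperparameters (betas, eps, weight_decay) once, while the is_paged/is_8bit properties give the trainer, the DeepSpeed strategy, and the embedding-replacement logic a uniform way to detect them.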
@DeveloperAPI def get_optimizer_conds(): """Returns a JSON schema of conditionals to validate against optimizer types defined in diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py index 40f9abbca90..ff770b6db8c 100644 --- a/ludwig/trainers/trainer.py +++ b/ludwig/trainers/trainer.py @@ -57,6 +57,7 @@ from ludwig.utils.data_utils import load_json from ludwig.utils.defaults import default_random_seed from ludwig.utils.fs_utils import path_exists +from ludwig.utils.llm_utils import update_embedding_layer from ludwig.utils.metric_utils import get_metric_names, TrainerMetric from ludwig.utils.metrics_printed_table import print_metrics_table from ludwig.utils.misc_utils import set_random_seed @@ -211,6 +212,8 @@ def prepare(self): base_learning_rate *= lr_scale_fn(self.distributed.size()) self.base_learning_rate = base_learning_rate + # We may need to replace the embedding layer when using 8-bit optimizers from bitsandbytes. + update_embedding_layer(self.compiled_model, self.config) self.dist_model, self.optimizer = self.distributed.prepare( self.compiled_model, self.config, diff --git a/ludwig/utils/llm_utils.py b/ludwig/utils/llm_utils.py index 995a9a787e2..412abf33006 100644 --- a/ludwig/utils/llm_utils.py +++ b/ludwig/utils/llm_utils.py @@ -1,10 +1,23 @@ +import logging from typing import Dict, Tuple import torch import torch.nn.functional as F -from transformers import GPT2Tokenizer, GPT2TokenizerFast, LlamaTokenizer, LlamaTokenizerFast, PreTrainedTokenizer +from bitsandbytes.nn.modules import Embedding +from transformers import ( + AutoModelForCausalLM, + GPT2Tokenizer, + GPT2TokenizerFast, + LlamaTokenizer, + LlamaTokenizerFast, + PreTrainedTokenizer, +) from ludwig.constants import IGNORE_INDEX_TOKEN_ID, LOGITS, PREDICTIONS, PROBABILITIES +from ludwig.schema.trainer import LLMTrainerConfig +from ludwig.utils.model_utils import find_embedding_layer_with_path + +logger = logging.getLogger(__name__) def set_pad_token(tokenizer: PreTrainedTokenizer): @@ -376,3 +389,48 @@ def realign_target_and_prediction_tensors_for_inference( targets[of_name] = F.pad(targets[of_name], (0, zeros_to_add), value=pad_value).to(torch.int64) return targets, predictions + + +def update_embedding_layer(model: AutoModelForCausalLM, config_obj: LLMTrainerConfig) -> AutoModelForCausalLM: + """Updates the embedding layer of the model to use the 8-bit embedding layer from bitsandbytes.nn.modules. + + This is necessary when using 8-bit optimizers from bitsandbytes. + See: https://github.com/TimDettmers/bitsandbytes#tldr + """ + # If we're using an 8-bit optimizer, we need to replace the embedding layer with a custom embedding layer from + # bnb.nn.modules.Embedding. + if hasattr(config_obj, "optimizer") and config_obj.optimizer.is_8bit: + embedding_layer, module_path = find_embedding_layer_with_path(model) + if embedding_layer is None: + raise ValueError( + "Could not find an embedding layer in the model. This is required when using 8-bit optimizers" + " since a custom 8-bit embedding layer is used in place of the original embedding layer." + ) + + # Initialize the BNB embedding layer with the same parameters and weights as the original embedding layer. 
+ bnb_embedding = Embedding( + num_embeddings=embedding_layer.num_embeddings, + embedding_dim=embedding_layer.embedding_dim, + padding_idx=embedding_layer.padding_idx, + max_norm=embedding_layer.max_norm, + norm_type=embedding_layer.norm_type, + scale_grad_by_freq=embedding_layer.scale_grad_by_freq, + sparse=embedding_layer.sparse, + _weight=embedding_layer.weight, + device=model.device, + ) + + # Update the model's original embedding layer to use the BNB embedding layer using the module_path + # returned by find_embedding_layer_with_path. + module_path = module_path.split(".") + module = model + for module_name in module_path[:-1]: + module = getattr(module, module_name) + setattr(module, module_path[-1], bnb_embedding) + + # Set the get input embeddings lambda function to return the BNB embedding layer + model.get_input_embeddings = lambda: bnb_embedding + + logger.info("Updated the pretrained embedding layer to use the embedding layer from bitsandbytes.") + + return model diff --git a/ludwig/utils/model_utils.py b/ludwig/utils/model_utils.py index f4a879859c4..5955dacbb34 100644 --- a/ludwig/utils/model_utils.py +++ b/ludwig/utils/model_utils.py @@ -74,3 +74,20 @@ def replace_tensors(m: torch.nn.Module, tensors: List[Dict], device: torch.devic name, torch.as_tensor(array, device=device, dtype=NUMPY_TO_TORCH_DTYPE.get(array.dtype)), ) + + +def find_embedding_layer_with_path(module, module_names=[]): + """Recursively search through a module to find an embedding layer and its module path. + + Returns a tuple containing the embedding layer and its module path. + """ + for name, child_module in module.named_children(): + if isinstance(child_module, torch.nn.Embedding): + # If an embedding layer is found, return it along with the module path + return child_module, ".".join(module_names + [name]) + else: + # Recursively search in the child module and update the module_names list + found, path = find_embedding_layer_with_path(child_module, module_names + [name]) + if found is not None: + return found, path + return None, None diff --git a/requirements.txt b/requirements.txt index ae1528cef32..e8edf83a555 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,6 +44,9 @@ sentencepiece # requirements for daft getdaft +# requirement for various paged and 8-bit optimizers +bitsandbytes<0.41.0 + # new data format support xlwt # excel xlrd>=2.0.1 # excel diff --git a/requirements_llm.txt b/requirements_llm.txt index 421c711b787..c036a7e7e3a 100644 --- a/requirements_llm.txt +++ b/requirements_llm.txt @@ -3,5 +3,4 @@ faiss-cpu accelerate loralib -bitsandbytes<0.41.0 peft>=0.4.0 diff --git a/tests/integration_tests/test_model_training_options.py b/tests/integration_tests/test_model_training_options.py index 1e46a5a4b23..bb150bd1339 100644 --- a/tests/integration_tests/test_model_training_options.py +++ b/tests/integration_tests/test_model_training_options.py @@ -238,6 +238,16 @@ def test_resume_training_mlflow(optimizer, tmp_path): @pytest.mark.parametrize("optimizer_type", optimizer_registry) def test_optimizers(optimizer_type, tmp_path): + if (optimizer_type in {"lars", "lamb", "lion"}) and ( + not torch.cuda.is_available() or torch.cuda.device_count() == 0 + ): + pytest.skip("Skip: lars, lamb, and lion optimizers require GPU and none are available.") + + if ("paged" in optimizer_type or "8bit" in optimizer_type) and ( + not torch.cuda.is_available() or torch.cuda.device_count() == 0 + ): + pytest.skip("Skip: paged and 8-bit optimizers require GPU and none are available.") + input_features, 
output_features = synthetic_test_data.get_feature_configs() config = { diff --git a/tests/ludwig/utils/test_model_utils.py b/tests/ludwig/utils/test_model_utils.py index ea6374979d4..2a3564255f3 100644 --- a/tests/ludwig/utils/test_model_utils.py +++ b/tests/ludwig/utils/test_model_utils.py @@ -1,6 +1,7 @@ import torch +from transformers import AutoModelForCausalLM -from ludwig.utils.model_utils import extract_tensors, replace_tensors +from ludwig.utils.model_utils import extract_tensors, find_embedding_layer_with_path, replace_tensors # Define a sample model for testing @@ -59,3 +60,38 @@ def test_replace_tensors(): for name, array in tensor_dict["buffers"].items(): assert name in module._buffers assert torch.allclose(module._buffers[name], torch.as_tensor(array, device=device)) + + +# Define a sample module structure for testing +class SampleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 20) + self.rnn = torch.nn.LSTM(20, 30) + + +def test_find_embedding_layer_with_path_simple(): + # Test case 1: Test the function with a simple module structure + module = SampleModule() + embedding_layer, path = find_embedding_layer_with_path(module) + assert embedding_layer is not None + assert isinstance(embedding_layer, torch.nn.Embedding) + assert path == "embedding" + + +def test_find_embedding_layer_with_path_complex(): + # Test case 2: Test the function with a more complex module structure including AutoModelForCausalLM + model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-LlamaForCausalLM") + + embedding_layer, path = find_embedding_layer_with_path(model) + assert embedding_layer is not None + assert isinstance(embedding_layer, torch.nn.Embedding) + assert path == "model.embed_tokens" + + +def test_no_embedding_layer(): + # Test case 3: Embedding layer is not present + no_embedding_model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10)) + embedding_layer, path = find_embedding_layer_with_path(no_embedding_model) + assert embedding_layer is None + assert path is None diff --git a/tests/training_success/test_training_success.py b/tests/training_success/test_training_success.py index d6320ee2aca..f54eb22beaf 100644 --- a/tests/training_success/test_training_success.py +++ b/tests/training_success/test_training_success.py @@ -69,6 +69,22 @@ def ecd_trainer_config_generator(static_schema: Dict[str, Any] = None) -> Tuple[ config["trainer"] = {"train_steps": 1} combined_configs = combine_configs(explored, config) + + # HACK(Arnav): Remove configs that have LARS, LAMB or Lion optimizers, or Paged or 8-bit optimizers. + # This is because they require GPUs. 
+ filtered_configs = [] + + for config, dataset in combined_configs: + optimizer_type = config.get("trainer", {}).get("optimizer", "").get("type", "") + + if optimizer_type not in {"lars", "lamb", "lion"} and not ( + "paged" in optimizer_type or "8bit" in optimizer_type + ): + filtered_configs.append((config, dataset)) + + # Replace combined_configs with the filtered_configs + combined_configs = filtered_configs + logging.info(f"Generated {len(combined_configs)} for ECD trainer combinatorial tests.") for config, dataset in combined_configs: From 4a72e67b58233b3a22256cb24fc700e0f05436a7 Mon Sep 17 00:00:00 2001 From: Jim Thompson Date: Mon, 11 Sep 2023 12:36:58 -0400 Subject: [PATCH 6/8] FIX: Failure in TabTransformer Combiner Unit test (#3596) --- tests/ludwig/combiners/test_combiners.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/ludwig/combiners/test_combiners.py b/tests/ludwig/combiners/test_combiners.py index 4fd5fa6e304..645d5afec71 100644 --- a/tests/ludwig/combiners/test_combiners.py +++ b/tests/ludwig/combiners/test_combiners.py @@ -672,19 +672,17 @@ def test_tabtransformer_combiner_number_and_binary_with_category( @pytest.mark.parametrize( "feature_list", # defines parameter for fixture features_to_test() [ - [ - ("binary", [BATCH_SIZE, 1]), - ], [ ("binary", [BATCH_SIZE, 1]), ("binary", [BATCH_SIZE, 1]), ], [ ("number", [BATCH_SIZE, 1]), + ("number", [BATCH_SIZE, 1]), ], [ ("number", [BATCH_SIZE, 1]), - ("number", [BATCH_SIZE, 1]), + ("binary", [BATCH_SIZE, 1]), ], ], ) From 63be68328c151976aaac3632d6a5314523dcbbe7 Mon Sep 17 00:00:00 2001 From: Jeff Kinnison Date: Mon, 11 Sep 2023 15:45:01 -0400 Subject: [PATCH 7/8] fix: Move target tensor to model output device in `check_module_parameters_updated` (#3567) --- tests/integration_tests/parameter_update_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/integration_tests/parameter_update_utils.py b/tests/integration_tests/parameter_update_utils.py index 62653d33ed6..fc232ab4994 100644 --- a/tests/integration_tests/parameter_update_utils.py +++ b/tests/integration_tests/parameter_update_utils.py @@ -69,15 +69,20 @@ def check_module_parameters_updated( # do update of model parameters optimizer.zero_grad() if isinstance(module_output, torch.Tensor): + module_target = module_target.to(device=module_output.device) loss = loss_function(module_output, target_tensor) elif isinstance(module_output, dict): if "logits" in module_output: + module_target = module_target.to(device=module_output["logits"].device) loss = loss_function(module_output["logits"], target_tensor) elif ENCODER_OUTPUT in module_output: + module_target = module_target.to(device=module_output[ENCODER_OUTPUT].device) loss = loss_function(module_output[ENCODER_OUTPUT], target_tensor) elif "combiner_output" in module_output: + module_target = module_target.to(device=module_output["combiner_output"].device) loss = loss_function(module_output["combiner_output"], target_tensor) elif isinstance(module_output, (list, tuple)): + module_target = module_target.to(device=module_output[0].device) loss = loss_function(module_output[0], target_tensor) else: raise ValueError(f"Unexpected output type. 
Module type found is {type(module_output)}") From 6178b482be83ca34576737d400e419b379559921 Mon Sep 17 00:00:00 2001 From: Infernaught <72055086+Infernaught@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:36:39 -0400 Subject: [PATCH 8/8] Allow user to specify huggingface link or local path to pretrained lora weights (#3572) --- ludwig/api.py | 8 +++++ ludwig/config_validation/checks.py | 24 +++++++++----- ludwig/constants.py | 1 + ludwig/models/llm.py | 45 ++++++++++++++++++++++---- ludwig/schema/llms/peft.py | 6 +++- tests/integration_tests/test_llm.py | 50 +++++++++++++++++++++++++++++ 6 files changed, 119 insertions(+), 15 deletions(-) diff --git a/ludwig/api.py b/ludwig/api.py index 942034cfc1c..6abe4068831 100644 --- a/ludwig/api.py +++ b/ludwig/api.py @@ -615,6 +615,14 @@ def on_epoch_end(self, trainer, progress_tracker, save_path): # auto tune batch size self._tune_batch_size(trainer, training_set, random_seed=random_seed) + if ( + self.config_obj.model_type == "LLM" + and trainer.config.type == "none" + and self.config_obj.adapter is not None + and self.config_obj.adapter.pretrained_adapter_weights is not None + ): + trainer.model.initialize_adapter() # Load pre-trained adapter weights for inference only + # train model if self.backend.is_coordinator(): print_boxed("TRAINING") diff --git a/ludwig/config_validation/checks.py b/ludwig/config_validation/checks.py index 1621ee4809f..b873c1087a2 100644 --- a/ludwig/config_validation/checks.py +++ b/ludwig/config_validation/checks.py @@ -494,6 +494,14 @@ def check_llm_finetuning_trainer_config(config: "ModelConfig"): # noqa: F821 if config.model_type != MODEL_LLM: return + if ( + config.trainer.type == "none" + and config.adapter is not None + and config.adapter.pretrained_adapter_weights is not None + ): + # If performing zero-shot, we must specify pretrained adapter weights + return + if config.adapter is not None and config.trainer.type != "finetune": raise ConfigValidationError("LLM finetuning requires trainer type to be finetune.") @@ -509,7 +517,11 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821 return # LLM finetuning is only supported by the finetune trainer type - if config.trainer.type != "finetune": + if ( + config.trainer.type != "finetune" + and config.adapter is not None + and config.adapter.pretrained_adapter_weights is not None + ): return # Using local backend, so skip the checks below @@ -529,9 +541,8 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821 def check_llm_finetuning_adalora_config(config: "ModelConfig"): """Checks that the adalora adapter is configured correctly. - It requires a set of target_modules to be specified in the config for the model. If it isn't specified by the user, - we also check against PEFT's predefined target module list for ADALORA to see if this key is present there. If - neither is true, AdaloraModel will run into issues downstream. + We check against PEFT's predefined target module list for ADALORA to see if this target_modules is present there. If + not, AdaloraModel will run into issues downstream. 
""" if config.model_type != MODEL_LLM: return @@ -545,10 +556,7 @@ def check_llm_finetuning_adalora_config(config: "ModelConfig"): from peft.utils import TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING model_config = _get_llm_model_config(config.base_model) - if ( - not config.adapter.target_modules - and model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING - ): + if model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: raise ConfigValidationError( f"Adalora adapter is not supported for {model_config.model_type} model. " f"Supported model types are: {list(TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING.keys())}. " diff --git a/ludwig/constants.py b/ludwig/constants.py index 226a60a4ed6..d2cc455df24 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -282,6 +282,7 @@ GENERATION = "generation" PROMPT = "prompt" ADAPTER = "adapter" +PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights" # CrossEntropyLoss for LLMs IGNORE_INDEX_TOKEN_ID = -100 diff --git a/ludwig/models/llm.py b/ludwig/models/llm.py index fc7319e3114..408d8c89cf4 100644 --- a/ludwig/models/llm.py +++ b/ludwig/models/llm.py @@ -216,18 +216,51 @@ def output_feature_decoder(self) -> OutputFeature: def initialize_adapter(self): """If an adapter config is provided, we want to wrap the model with a PEFT model for fine-tuning.""" if self.config_obj.adapter: - if self.config_obj.trainer.type != "finetune": + if self.config_obj.trainer.type != "finetune" and not self.config_obj.adapter.pretrained_adapter_weights: raise ValueError( "Adapter config was provided, but trainer type is not set to `finetune`. Either set the trainer to " "`finetune` or remove the adapter config." ) - from peft import get_peft_model, TaskType + from peft import get_peft_model + + if self.config_obj.adapter.pretrained_adapter_weights: + logger.info(f"Using pretrained adapter weights: {self.config_obj.adapter.pretrained_adapter_weights}") + # If pretrained adapter weights are provided, we want to load them into the model + from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig + + peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_adapter_weights) + peft_dict = peft_config.to_dict() + + # Need to update the peft config with some of the values from config_obj because not all of them are set + for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items(): + # Not all parameters are supported by all models, so we only add the parameter to the load kwargs + # if it is supported by the model. + if param_value is None: + # param_name and param_value come from the config object and contain default + # values for the adapter. Examples of parameters with missing values might be: + # 'auto_mapping', 'base_model_name_or_path', and 'task_type'. + # Note that some of these values might already be set in peft_config, which comes from HF + # directly (specifically, adapter_config.json in the model repo), and we don't want to override + # those values with None. + continue + if param_name not in peft_dict: + # If any parameters are not set in adapter_config.json in HF, we want to populate them with the + # appropriate default values. 
+ setattr(peft_config, param_name, param_value) + + self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained( + self.model, self.config_obj.adapter.pretrained_adapter_weights + ) + else: + # If no pretrained adapter is provided, we want to load untrained weights into the model + from peft import TaskType - peft_config = self.config_obj.adapter.to_config( - task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name - ) - self.model = get_peft_model(self.model, peft_config) + peft_config = self.config_obj.adapter.to_config( + task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name + ) + + self.model = get_peft_model(self.model, peft_config) logger.info("==================================================") logger.info("Trainable Parameter Summary For Fine-Tuning") diff --git a/ludwig/schema/llms/peft.py b/ludwig/schema/llms/peft.py index 6e127ee5fbf..3ce30aeb07c 100644 --- a/ludwig/schema/llms/peft.py +++ b/ludwig/schema/llms/peft.py @@ -30,6 +30,10 @@ def wrap(config: BaseAdapterConfig): class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC): type: str + pretrained_adapter_weights: Optional[str] = schema_utils.String( + default=None, description="Path to pretrained weights.", allow_none=True + ) + @abstractmethod def to_config(self, **kwargs) -> "PeftConfig": pass @@ -359,7 +363,7 @@ def description(cls) -> str: @register_adapter("adaption_prompt") @ludwig_dataclass class AdaptionPromptConfig(BaseAdapterConfig): - """Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt.py.""" + """Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt/config.py.""" def __post_init__(self): if not self.adapter_len: diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py index a10d7745b39..91c53c9a7e8 100644 --- a/tests/integration_tests/test_llm.py +++ b/tests/integration_tests/test_llm.py @@ -18,6 +18,7 @@ MODEL_TYPE, OUTPUT_FEATURES, PREPROCESSING, + PRETRAINED_ADAPTER_WEIGHTS, PROMPT, TRAINER, TYPE, @@ -492,12 +493,61 @@ def test_default_max_sequence_length(): BATCH_SIZE: 8, EPOCHS: 2, }, + ADAPTER: {TYPE: "lora", PRETRAINED_ADAPTER_WEIGHTS: "Infernaught/test_adapter_weights"}, + BACKEND: {TYPE: "local"}, } config_obj = ModelConfig.from_dict(config) assert config_obj.input_features[0].preprocessing.max_sequence_length is None assert config_obj.output_features[0].preprocessing.max_sequence_length is None +@pytest.mark.parametrize("adapter", ["lora", "adalora", "adaption_prompt"]) +def test_load_pretrained_adapter_weights(adapter): + from peft import PeftModel + from transformers import PreTrainedModel + + weights = "" + model = "" + if adapter == "lora": + weights = "Infernaught/test_adapter_weights" + base_model = TEST_MODEL_NAME + elif adapter == "adalora": + weights = "Infernaught/test_adalora_weights" + base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM" + elif adapter == "adaption_prompt": + weights = "Infernaught/test_ap_weights" + base_model = "HuggingFaceH4/tiny-random-LlamaForCausalLM" + else: + raise () + + config = { + MODEL_TYPE: MODEL_LLM, + BASE_MODEL: base_model, + INPUT_FEATURES: [text_feature(name="input", encoder={"type": "passthrough"})], + OUTPUT_FEATURES: [text_feature(name="output")], + TRAINER: { + TYPE: "none", + BATCH_SIZE: 8, + EPOCHS: 2, + }, + ADAPTER: {TYPE: adapter, PRETRAINED_ADAPTER_WEIGHTS: weights}, + BACKEND: {TYPE: "local"}, + } + config_obj = ModelConfig.from_dict(config) + model = 
LLM(config_obj) + + assert model.config_obj.adapter.pretrained_adapter_weights + assert model.config_obj.adapter.pretrained_adapter_weights == weights + + model.prepare_for_training() + assert not isinstance(model.model, PreTrainedModel) + assert isinstance(model.model, PeftModel) + + config_obj = ModelConfig.from_dict(config) + assert config_obj.input_features[0].preprocessing.max_sequence_length is None + assert config_obj.output_features[0].preprocessing.max_sequence_length is None + + def _compare_models(model_1: torch.nn.Module, model_2: torch.nn.Module) -> bool: # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6 for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
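[Editor's note: illustrative example, not part of this patch] PATCH 8/8 allows an adapter config to point at pretrained weights (a Hugging Face Hub repo or a local path) and load them for inference-only use with trainer type "none". The sketch below assembles that flow end to end; the base model and adapter repo mirror the adalora case in the test above, and the dataset path is a hypothetical placeholder.

# Minimal sketch of loading pretrained adapter weights for inference only.
from ludwig.api import LudwigModel

config = {
    "model_type": "llm",
    "base_model": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
    "input_features": [{"name": "input", "type": "text"}],
    "output_features": [{"name": "output", "type": "text"}],
    "adapter": {"type": "adalora", "pretrained_adapter_weights": "Infernaught/test_adalora_weights"},
    "trainer": {"type": "none"},
    "backend": {"type": "local"},
}

model = LudwigModel(config)
# With trainer type "none" no weight updates are performed; the api.py change above
# calls initialize_adapter() so the pretrained adapter is loaded before evaluation.
model.train(dataset="examples.csv")  # hypothetical dataset path
predictions, _ = model.predict(dataset="examples.csv")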