Merge branch 'master' of github.com:ludwig-ai/ludwig into release-0.8
justinxzhao committed Sep 12, 2023
2 parents 5bd9287 + 6178b48 commit d3008c6
Showing 21 changed files with 842 additions and 29 deletions.
8 changes: 8 additions & 0 deletions ludwig/api.py
@@ -615,6 +615,14 @@ def on_epoch_end(self, trainer, progress_tracker, save_path):
# auto tune batch size
self._tune_batch_size(trainer, training_set, random_seed=random_seed)

if (
self.config_obj.model_type == "LLM"
and trainer.config.type == "none"
and self.config_obj.adapter is not None
and self.config_obj.adapter.pretrained_adapter_weights is not None
):
trainer.model.initialize_adapter() # Load pre-trained adapter weights for inference only

# train model
if self.backend.is_coordinator():
print_boxed("TRAINING")
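The api.py hook above only fires for LLM models when the trainer type is "none" and `adapter.pretrained_adapter_weights` is set, so the adapter is loaded for inference rather than training. Below is a minimal sketch of a config that might exercise this path through the Python API; the base model, adapter repository, dataset path, and feature names are placeholders, not values taken from this commit.

from ludwig.api import LudwigModel

# Hypothetical config: model name, adapter repo, and column names are illustrative only.
config = {
    "model_type": "llm",
    "base_model": "facebook/opt-350m",
    "adapter": {
        "type": "lora",
        # HF repo or local path containing adapter_config.json and trained LoRA weights.
        "pretrained_adapter_weights": "someuser/opt-350m-lora-adapter",
    },
    "trainer": {"type": "none"},  # no optimization; the adapter is loaded for inference only
    "input_features": [{"name": "prompt", "type": "text"}],
    "output_features": [{"name": "answer", "type": "text"}],
}

model = LudwigModel(config)
# model.train(dataset="examples.csv")  # with the "none" trainer this skips optimization but
#                                      # triggers initialize_adapter() as in the hook above
# predictions, _ = model.predict(dataset="examples.csv")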
79 changes: 71 additions & 8 deletions ludwig/config_validation/checks.py
@@ -1,6 +1,7 @@
"""Checks that are not easily covered by marshmallow JSON schema validation like parameter interdependencies."""

from abc import ABC, abstractmethod
from re import findall
from typing import Callable, TYPE_CHECKING

from transformers import AutoConfig
@@ -493,6 +494,14 @@ def check_llm_finetuning_trainer_config(config: "ModelConfig"): # noqa: F821
if config.model_type != MODEL_LLM:
return

if (
config.trainer.type == "none"
and config.adapter is not None
and config.adapter.pretrained_adapter_weights is not None
):
# Zero-shot (inference-only) use of pretrained adapter weights with the 'none' trainer is allowed, so skip the finetune checks below
return

if config.adapter is not None and config.trainer.type != "finetune":
raise ConfigValidationError("LLM finetuning requires trainer type to be finetune.")

@@ -508,7 +517,11 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821
return

# LLM finetuning is only supported by the finetune trainer type
- if config.trainer.type != "finetune":
if (
config.trainer.type != "finetune"
and config.adapter is not None
and config.adapter.pretrained_adapter_weights is not None
):
return

# Using local backend, so skip the checks below
@@ -528,9 +541,8 @@ def check_llm_finetuning_backend_config(config: "ModelConfig"): # noqa: F821
def check_llm_finetuning_adalora_config(config: "ModelConfig"):
"""Checks that the adalora adapter is configured correctly.
- It requires a set of target_modules to be specified in the config for the model. If it isn't specified by the user,
- we also check against PEFT's predefined target module list for ADALORA to see if this key is present there. If
- neither is true, AdaloraModel will run into issues downstream.
We check against PEFT's predefined target module list for ADALORA to see whether target_modules is defined for this
model type. If not, AdaloraModel will run into issues downstream.
"""
if config.model_type != MODEL_LLM:
return
@@ -544,10 +556,7 @@ def check_llm_finetuning_adalora_config(config: "ModelConfig"):
from peft.utils import TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING

model_config = _get_llm_model_config(config.base_model)
- if (
- not config.adapter.target_modules
- and model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING
- ):
if model_config.model_type not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
raise ConfigValidationError(
f"Adalora adapter is not supported for {model_config.model_type} model. "
f"Supported model types are: {list(TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING.keys())}. "
@@ -606,3 +615,57 @@ def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821

if config.quantization and (not config.adapter or config.adapter.type != "lora"):
raise ConfigValidationError("Fine-tuning and LLM with quantization requires using the 'lora' adapter")


@register_config_check
def check_prompt_requirements(config: "ModelConfig") -> None: # noqa: F821
"""Checks that prompt's template and task properties are valid, according to the description on the schema."""
if config.model_type != MODEL_LLM:
return

# TODO: `prompt` by default should be set to null, not a default dict:
# # If no prompt is provided, no validation necessary:
# if not config.prompt:
# return
from ludwig.schema.llms.prompt import PromptConfig, RetrievalConfig

if config.prompt == PromptConfig():
return

template = config.prompt.template
task = config.prompt.task
retrieval = config.prompt.retrieval

# If template is NOT provided, then task is required for zero/few shot learning:
if not template and not task:
raise ConfigValidationError("A prompt task is required if no template is provided!")

template_refs = set(findall(r"\{(.*?)\}", template)) if isinstance(template, str) else set()

# If a template IS provided (i.e. we are not doing built-in zero/few-shot learning), then...
if template:
# If task is also provided, the template must contain it:
if task and "__task__" not in template_refs:
raise ConfigValidationError(
"When providing a task, you must make sure that the task keyword `{__task__} is "
"present somewhere in the template string!"
)

# If retrieval is also provided, the template must reference it:
# TODO: retrieval by default should be set to null, not a default dict:
if retrieval and retrieval != RetrievalConfig() and "__context__" not in template_refs:
raise ConfigValidationError(
"When providing a retrieval config, you must make sure that the task keyword `{__context__}` is "
"present somewhere in the template string!"
)

# Otherwise, the template should at least contain the sample keyword or some input column:
# TODO: len(template_refs) is a hacky attempt to check that there are references to *something* in the
# string. The proper validation is to check the references against the features in the user's dataset - but we
# do not have access to the dataset in this code path right now.
if not task:
if len(template_refs) == 0 and "__sample__" not in template_refs:
raise ConfigValidationError(
"A template must contain at least one reference to a column or the sample keyword {__sample__} for "
"a JSON-serialized representation of non-output feature columns."
)
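To make the placeholder validation above concrete, the following standalone sketch shows how the `findall` pattern extracts references from a template string; the template text itself is made up for illustration.

import re

template = "Context: {__context__}\nTask: {__task__}\nSample: {__sample__}\nTitle: {title}"
template_refs = set(re.findall(r"\{(.*?)\}", template))
print(template_refs)  # {'__task__', '__context__', '__sample__', 'title'} (set order may vary)

# A config using this template together with `task` and `retrieval` passes the checks above,
# because both __task__ and __context__ are referenced.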
1 change: 1 addition & 0 deletions ludwig/constants.py
@@ -282,6 +282,7 @@
GENERATION = "generation"
PROMPT = "prompt"
ADAPTER = "adapter"
PRETRAINED_ADAPTER_WEIGHTS = "pretrained_adapter_weights"

# CrossEntropyLoss for LLMs
IGNORE_INDEX_TOKEN_ID = -100
10 changes: 10 additions & 0 deletions ludwig/contribs/wandb.py
@@ -53,6 +53,16 @@ def on_train_start(self, model, config, *args, **kwargs):
del config["output_features"]
wandb.config.update(config)

def on_eval_end(self, trainer, progress_tracker, save_path):
"""Called from ludwig/models/model.py."""
for key, value in progress_tracker.log_metrics().items():
wandb.log({key: value})

def on_epoch_end(self, trainer, progress_tracker, save_path):
"""Called from ludwig/models/model.py."""
for key, value in progress_tracker.log_metrics().items():
wandb.log({key: value})

def on_visualize_figure(self, fig):
logger.info("wandb.on_visualize_figure() called...")
if wandb.run:
4 changes: 4 additions & 0 deletions ludwig/distributed/deepspeed.py
@@ -82,6 +82,10 @@ def prepare(
batch_size = (
trainer_config.batch_size if isinstance(trainer_config.batch_size, int) else MIN_POSSIBLE_BATCH_SIZE
)
# Paged and 8-bit optimizers are not supported by DeepSpeed; it only accepts optimizers compatible with
# torch.optim.Optimizer. https://www.deepspeed.ai/docs/config-json/#optimizer-parameters.
if trainer_config.optimizer.is_paged or trainer_config.optimizer.is_8bit:
raise ValueError("Cannot use a paged or 8-bit optimizer with DeepSpeed.")
optimizer_cls, optimizer_kwargs = get_optimizer_class_and_kwargs(trainer_config.optimizer, base_learning_rate)
ds_config = {
"amp": {
45 changes: 39 additions & 6 deletions ludwig/models/llm.py
@@ -216,18 +216,51 @@ def output_feature_decoder(self) -> OutputFeature:
def initialize_adapter(self):
"""If an adapter config is provided, we want to wrap the model with a PEFT model for fine-tuning."""
if self.config_obj.adapter:
- if self.config_obj.trainer.type != "finetune":
if self.config_obj.trainer.type != "finetune" and not self.config_obj.adapter.pretrained_adapter_weights:
raise ValueError(
"Adapter config was provided, but trainer type is not set to `finetune`. Either set the trainer to "
"`finetune` or remove the adapter config."
)

- from peft import get_peft_model, TaskType
from peft import get_peft_model

if self.config_obj.adapter.pretrained_adapter_weights:
logger.info(f"Using pretrained adapter weights: {self.config_obj.adapter.pretrained_adapter_weights}")
# If pretrained adapter weights are provided, we want to load them into the model
from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig

peft_config = PeftConfig.from_pretrained(self.config_obj.adapter.pretrained_adapter_weights)
peft_dict = peft_config.to_dict()

# Need to update the peft config with some of the values from config_obj because not all of them are set
for param_name, param_value in self.config_obj.adapter.to_config().to_dict().items():
# Not all parameters are supported by all models, so we only add the parameter to the load kwargs
# if it is supported by the model.
if param_value is None:
# param_name and param_value come from the config object and contain default
# values for the adapter. Examples of parameters with missing values might be:
# 'auto_mapping', 'base_model_name_or_path', and 'task_type'.
# Note that some of these values might already be set in peft_config, which comes from HF
# directly (specifically, adapter_config.json in the model repo), and we don't want to override
# those values with None.
continue
if param_name not in peft_dict:
# If any parameters are not set in adapter_config.json in HF, we want to populate them with the
# appropriate default values.
setattr(peft_config, param_name, param_value)

self.model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained(
self.model, self.config_obj.adapter.pretrained_adapter_weights
)
else:
# If no pretrained adapter is provided, we want to load untrained weights into the model
from peft import TaskType

peft_config = self.config_obj.adapter.to_config(
task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
)
self.model = get_peft_model(self.model, peft_config)
- peft_config = self.config_obj.adapter.to_config(
- task_type=TaskType.CAUSAL_LM, tokenizer_name_or_path=self.model_name
- )
-
- self.model = get_peft_model(self.model, peft_config)

logger.info("==================================================")
logger.info("Trainable Parameter Summary For Fine-Tuning")
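For reference, the pretrained-weights branch added above follows the standard PEFT loading pattern: read the adapter's adapter_config.json, pick the PEFT model class for its task type, then load the weights onto the base model. A minimal standalone sketch, with placeholder model and adapter names that are not part of this commit:

from peft import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PeftConfig
from transformers import AutoModelForCausalLM

base_model_name = "facebook/opt-350m"            # placeholder base model
adapter_repo = "someuser/opt-350m-lora-adapter"  # placeholder adapter repo

base_model = AutoModelForCausalLM.from_pretrained(base_model_name)

# adapter_config.json tells us the task type (e.g. CAUSAL_LM); the mapping picks the
# matching PEFT model class, whose from_pretrained() loads the adapter weights.
peft_config = PeftConfig.from_pretrained(adapter_repo)
peft_model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type].from_pretrained(
    base_model, adapter_repo
)
peft_model.eval()  # inference only, mirroring the zero-shot path in this commit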
9 changes: 9 additions & 0 deletions ludwig/modules/optimization_modules.py
@@ -65,5 +65,14 @@ def create_optimizer(
:param optimizer_config: Instance of `ludwig.modules.optimization_modules.BaseOptimizerConfig`.
:return: Initialized instance of a torch optimizer.
"""
# Make sure the optimizer is compatible with the available resources:
if (optimizer_config.is_paged or optimizer_config.is_8bit) and (
not torch.cuda.is_available() or torch.cuda.device_count() == 0
):
raise ValueError(
"Cannot use a paged or 8-bit optimizer on a non-GPU machine. "
"Please use a different optimizer or run on a machine with a GPU."
)

optimizer_cls, optimizer_kwargs = get_optimizer_class_and_kwargs(optimizer_config, learning_rate)
return optimizer_cls(model.parameters(), **optimizer_kwargs)
6 changes: 5 additions & 1 deletion ludwig/schema/llms/peft.py
@@ -30,6 +30,10 @@ def wrap(config: BaseAdapterConfig):
class BaseAdapterConfig(schema_utils.BaseMarshmallowConfig, ABC):
type: str

pretrained_adapter_weights: Optional[str] = schema_utils.String(
default=None, description="Path to pretrained weights.", allow_none=True
)

@abstractmethod
def to_config(self, **kwargs) -> "PeftConfig":
pass
@@ -359,7 +363,7 @@ def description(cls) -> str:
@register_adapter("adaption_prompt")
@ludwig_dataclass
class AdaptionPromptConfig(BaseAdapterConfig):
"""Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt.py."""
"""Adapted from https://github.com/huggingface/peft/blob/main/src/peft/tuners/adaption_prompt/config.py."""

def __post_init__(self):
if not self.adapter_len: