schema: Add prompt validation check (#3564)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ksbrar and pre-commit-ci[bot] authored Sep 6, 2023

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 07ea472 commit 3d2ff0b
Showing 2 changed files with 82 additions and 0 deletions.
55 changes: 55 additions & 0 deletions ludwig/config_validation/checks.py
@@ -1,6 +1,7 @@
"""Checks that are not easily covered by marshmallow JSON schema validation like parameter interdependencies."""

from abc import ABC, abstractmethod
from re import findall
from typing import Callable, TYPE_CHECKING

from transformers import AutoConfig
@@ -606,3 +607,57 @@ def check_qlora_requirements(config: "ModelConfig") -> None: # noqa: F821

if config.quantization and (not config.adapter or config.adapter.type != "lora"):
raise ConfigValidationError("Fine-tuning and LLM with quantization requires using the 'lora' adapter")


@register_config_check
def check_prompt_requirements(config: "ModelConfig") -> None: # noqa: F821
"""Checks that prompt's template and task properties are valid, according to the description on the schema."""
if config.model_type != MODEL_LLM:
return

# TODO: `prompt` by default should be set to null, not a default dict:
# # If no prompt is provided, no validation necessary:
# if not config.prompt:
# return
from ludwig.schema.llms.prompt import PromptConfig, RetrievalConfig

if config.prompt == PromptConfig():
return

template = config.prompt.template
task = config.prompt.task
retrieval = config.prompt.retrieval

    # If a template is NOT provided, then a task is required for zero/few-shot learning:
if not template and not task:
raise ConfigValidationError("A prompt task is required if no template is provided!")

template_refs = set(findall(r"\{(.*?)\}", template)) if isinstance(template, str) else set()

# If a template IS provided (i.e. we are not doing a built-in zero/few-shot learning), then...
if template:
# If task is also provided, the template must contain it:
if task and "__task__" not in template_refs:
raise ConfigValidationError(
"When providing a task, you must make sure that the task keyword `{__task__} is "
"present somewhere in the template string!"
)

# If retrieval is also provided, the template must reference it:
# TODO: retrieval by default should be set to null, not a default dict:
if retrieval and retrieval != RetrievalConfig() and "__context__" not in template_refs:
raise ConfigValidationError(
"When providing a retrieval config, you must make sure that the task keyword `{__context__}` is "
"present somewhere in the template string!"
)

# Otherwise, the template should at least contain the sample keyword or some input column:
# TODO: len(template_refs) is a hacky attempt to check that there are references to *something* in the
# string. The proper validation is to check the references against the features in the user's dataset - but we
# do not have access to the dataset in this code path right now.
if not task:
            if len(template_refs) == 0:
raise ConfigValidationError(
"A template must contain at least one reference to a column or the sample keyword {__sample__} for "
"a JSON-serialized representation of non-output feature columns."
)
27 changes: 27 additions & 0 deletions tests/ludwig/config_validation/test_checks.py
@@ -459,3 +459,30 @@ def test_check_qlora():
"type": "lora",
}
ModelConfig.from_dict(config)


def test_check_prompt_requirements():
config = {
"model_type": "llm",
"input_features": [
text_feature(name="test1", column="col1", encoder={"type": "passthrough"}),
],
"output_features": [text_feature()],
"base_model": "opt-350m",
}

ModelConfig.from_dict(config)

config["prompt"] = {"task": "Some task"}
ModelConfig.from_dict(config)

config["prompt"] = {"task": "Some task", "template": "Some template not mentioning the task"}
with pytest.raises(ConfigValidationError):
ModelConfig.from_dict(config)

config["prompt"] = {"task": "Some task", "template": "{__invalid__}"}
with pytest.raises(ConfigValidationError):
ModelConfig.from_dict(config)

config["prompt"] = {"task": "Some task", "template": "{__task__}"}
ModelConfig.from_dict(config)
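
Beyond the failure cases above, a hypothetical follow-up (not part of the commit) that would satisfy every branch of the check, since the template references the task, the retrieval context, and the sample keyword; the retrieval field names below are assumptions about RetrievalConfig, not taken from the commit:

    # Hypothetical passing case: task, retrieval, and sample references all present.
    config["prompt"] = {
        "task": "Some task",
        "retrieval": {"type": "semantic", "k": 2},  # assumed RetrievalConfig fields
        "template": "Context: {__context__}\nTask: {__task__}\nInput: {__sample__}",
    }
    ModelConfig.from_dict(config)
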
