Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GenerationConfig and SafetySettings parameters to Google Cloud Multimodal Model Operators #40126

Merged
merged 4 commits into from
Jun 14, 2024
Merged

Add GenerationConfig and SafetySettings parameters to Google Cloud Multimodal Model Operators #40126

merged 4 commits into from
Jun 14, 2024

Conversation

CYarros10
Copy link
Contributor

@CYarros10 CYarros10 commented Jun 7, 2024

Gemini / Multimodal model Airflow Hooks/Operators should mimic typical usage as defined in the sample code block below. This provides fine-tuning model output via the additional configurations GenerationConfig and SafetySettings

Updated Operator Usage

    prompt_multimodal_model_task = PromptMultimodalModelOperator(
        task_id="prompt_multimodal_model_task",
        project_id=PROJECT_ID,
        location=REGION,
        prompt=PROMPT,
        generation_config=GENERATION_CONFIG,
        safety_settings=SAFETY_SETTINGS,
        pretrained_model=MULTIMODAL_MODEL,
    )

Sample Code

vertexai.init(project=project_id, location="us-central1")

model = generative_models.GenerativeModel(model_name="gemini-1.0-pro-vision-001")

# Generation config
generation_config = generative_models.GenerationConfig(
    max_output_tokens=2048, temperature=0.4, top_p=1, top_k=32
)

# Safety config
safety_config = [
    generative_models.SafetySetting(
        category=generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    ),
    generative_models.SafetySetting(
        category=generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=generative_models.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    ),
]

image_file = Part.from_uri(
    "gs://cloud-samples-data/generative-ai/image/scones.jpg", "image/jpeg"
)

# Generate content
responses = model.generate_content(
    [image_file, "What is in this image?"],
    generation_config=generation_config,
    safety_settings=safety_config,
    stream=True,
)

^ Add meaningful description above
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in newsfragments.

@boring-cyborg boring-cyborg bot added area:providers area:system-tests provider:google Google (including GCP) related issues labels Jun 7, 2024
Copy link
Contributor

@MaksYermak MaksYermak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@amoghrajesh amoghrajesh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

@potiuk potiuk merged commit e2b8f68 into apache:main Jun 14, 2024
108 checks passed
romsharon98 pushed a commit to romsharon98/airflow that referenced this pull request Jul 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area:providers area:system-tests provider:google Google (including GCP) related issues
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants