Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add information about Amazon Elastic MapReduce Connection #26687

Merged
merged 7 commits into from
Oct 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 92 additions & 13 deletions airflow/providers/amazon/aws/hooks/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import json
import warnings
from time import sleep
from typing import Any, Callable
Expand All @@ -30,8 +31,11 @@

class EmrHook(AwsBaseHook):
"""
Interact with AWS EMR. emr_conn_id is only necessary for using the
create_job_flow method.
Interact with Amazon Elastic MapReduce Service.

:param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`.
This attribute is only necessary when using
the :meth:`~airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` method.

Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
Expand All @@ -45,8 +49,8 @@ class EmrHook(AwsBaseHook):
conn_type = 'emr'
hook_name = 'Amazon Elastic MapReduce'

def __init__(self, emr_conn_id: str | None = default_conn_name, *args, **kwargs) -> None:
    # Remember which (optional) EMR connection holds the initial job-flow config;
    # it is only consulted by ``create_job_flow``.
    self.emr_conn_id = emr_conn_id
    # Force the underlying AwsBaseHook to build a boto3 "emr" client.
    kwargs["client_type"] = "emr"
    super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -77,22 +81,97 @@ def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str

def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]:
    """
    Create and start running a new cluster (job flow).

    This method uses ``EmrHook.emr_conn_id`` to receive the initial Amazon EMR cluster configuration.
    If ``EmrHook.emr_conn_id`` is empty or the connection does not exist, then an empty initial
    configuration is used.

    :param job_flow_overrides: Is used to overwrite the parameters in the initial Amazon EMR configuration
        cluster. The resulting configuration will be used in the boto3 emr client run_job_flow method.

    .. seealso::
        - :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`
        - `API RunJobFlow <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_
        - `boto3 emr client run_job_flow method <https://boto3.amazonaws.com/v1/documentation/\
          api/latest/reference/services/emr.html#EMR.Client.run_job_flow>`_.
    """
    config: dict[str, Any] = {}
    if self.emr_conn_id:
        try:
            conn = self.get_connection(self.emr_conn_id)
        except AirflowNotFoundException:
            # A missing connection is tolerated: warn and fall back to an
            # empty initial configuration instead of failing the task.
            warnings.warn(
                f"Unable to find {self.hook_name} Connection ID {self.emr_conn_id!r}, "
                "using an empty initial configuration. If you want to get rid of this warning "
                "message please provide a valid `emr_conn_id` or set it to None.",
                UserWarning,
                stacklevel=2,
            )
        else:
            if conn.conn_type and conn.conn_type != self.conn_type:
                # Wrong conn_type is suspicious but not fatal; the extras may
                # still hold a valid run_job_flow configuration.
                warnings.warn(
                    f"{self.hook_name} Connection expected connection type {self.conn_type!r}, "
                    f"Connection {self.emr_conn_id!r} has conn_type={conn.conn_type!r}. "
                    f"This connection might not work correctly.",
                    UserWarning,
                    stacklevel=2,
                )
            config = conn.extra_dejson.copy()

    # Explicit per-call overrides always win over the connection extras.
    config.update(job_flow_overrides)
    return self.get_conn().run_job_flow(**config)

def test_connection(self):
    """
    Return failed state for test Amazon Elastic MapReduce Connection (untestable).

    We need to overwrite this method because this hook is based on
    :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`,
    otherwise it will try to test connection to AWS STS by using the default boto3 credential strategy.
    """
    message = (
        f"{self.hook_name!r} Airflow Connection cannot be tested, by design it stores "
        "only key/value pairs and does not make a connection to an external resource."
    )
    return False, message

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
    """Returns custom UI field behaviour for Amazon Elastic MapReduce Connection."""
    # Example ``run_job_flow`` request body, shown in the UI as the
    # placeholder for the (relabeled) Extra field.
    sample_config = {
        "Name": "MyClusterName",
        "ReleaseLabel": "emr-5.36.0",
        "Applications": [{"Name": "Spark"}],
        "Instances": {
            "InstanceGroups": [
                {
                    "Name": "Primary node",
                    "Market": "SPOT",
                    "InstanceRole": "MASTER",
                    "InstanceType": "m5.large",
                    "InstanceCount": 1,
                },
            ],
            "KeepJobFlowAliveWhenNoSteps": False,
            "TerminationProtected": False,
        },
        "StepConcurrencyLevel": 2,
    }
    return {
        # Only the Extra field is meaningful for this connection type.
        "hidden_fields": ["host", "schema", "port", "login", "password"],
        "relabeling": {
            "extra": "Run Job Flow Configuration",
        },
        "placeholders": {"extra": json.dumps(sample_config, indent=2)},
    }


class EmrServerlessHook(AwsBaseHook):
"""
Expand Down
15 changes: 8 additions & 7 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,13 @@ class EmrCreateJobFlowOperator(BaseOperator):
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node)
:param emr_conn_id: emr connection to use for run_job_flow request body.
This will be overridden by the job_flow_overrides param
:param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`.
Use to receive an initial Amazon EMR cluster configuration:
``boto3.client('emr').run_job_flow`` request body.
If this is None or empty or the connection does not exist,
then an empty initial configuration is used.
:param job_flow_overrides: boto3 style arguments or reference to an arguments file
(must be '.json') to override emr_connection extra. (templated)
(must be '.json') to override specific ``emr_conn_id`` extra parameters. (templated)
:param region_name: Region name passed to EmrHook
"""

def __init__(
    self,
    *,
    aws_conn_id: str = 'aws_default',
    emr_conn_id: str | None = 'emr_default',
    job_flow_overrides: str | dict[str, Any] | None = None,
    region_name: str | None = None,
    **kwargs,
):
    super().__init__(**kwargs)
    self.aws_conn_id = aws_conn_id
    self.emr_conn_id = emr_conn_id
    # Normalise a falsy value (None, '', {}) to an empty overrides dict.
    self.job_flow_overrides = job_flow_overrides or {}
    self.region_name = region_name

def execute(self, context: Context) -> str:
Expand Down
43 changes: 43 additions & 0 deletions docs/apache-airflow-providers-amazon/connections/emr.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

.. http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.

.. _howto/connection:emr:

Amazon Elastic MapReduce (EMR) Connection
=========================================

.. note::
This connection type is only used to store parameters needed to start an EMR cluster (the `run_job_flow` boto3 EMR client method).

This connection is not intended to store any credentials for the ``boto3`` client. If you try to pass any
parameters not listed in the `RunJobFlow API <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_
you will get an error like this:

.. code-block:: text

Parameter validation failed: Unknown parameter in input: "region_name", must be one of:

For authenticating to AWS, please use the :ref:`Amazon Web Services Connection <howto/connection:aws>`.

Configuring the Connection
--------------------------

Extra (optional)
Specify the parameters (as a `json` dictionary) that can be used as an initial configuration
in :meth:`airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` to propagate to
`RunJobFlow API <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_.
All parameters are optional.
61 changes: 56 additions & 5 deletions tests/providers/amazon/aws/hooks/test_emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import boto3
import pytest

from airflow.providers.amazon.aws.hooks.emr import EmrHook

Expand All @@ -29,8 +30,8 @@
mock_emr = None


@unittest.skipIf(mock_emr is None, 'moto package not present')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Love the conversion to pytest here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I try to replace unittest with pytest; in most cases this does not require huge effort.
However, there are still a lot of unittest-based tests:

# Total number of files which use `unittest.TestCase `
❯ grep -rl 'unittest.TestCase' ./tests | wc -l
528

# By tests packages
❯ grep -rl 'unittest.TestCase' ./tests | cut -d"/" -f3 | sort | uniq -c | sort -nr
    379 providers
     36 charts
     23 utils
     20 cli
     10 ti_deps
     10 api_connexion
      6 sensors
      6 operators
      6 executors
      6 core
      5 www
      5 always
      4 models
      3 api
      2 task
      2 kubernetes
      1 security
      1 plugins
      1 macros
      1 hooks
      1 dag_processing

# By provider ("apache", "microsoft", "common" has subpackages which is separate provider)
❯ grep -rl 'unittest.TestCase' ./tests/providers/ | cut -d"/" -f5 | sort | uniq -c | sort -nr
    113 google
    101 amazon
     32 apache
     26 microsoft
      6 databricks
      4 redis
      4 mysql
      4 alibaba
      3 trino
      3 tableau
      3 qubole
      3 oracle
      3 jenkins
      3 http
      3 docker
      3 atlassian
      3 arangodb
      3 airbyte
      2 yandex
      2 vertica
      2 telegram
      2 sqlite
      2 snowflake
      2 sftp
      2 segment
      2 salesforce
      2 presto
      2 postgres
      2 opsgenie
      2 neo4j
      2 mongo
      2 jdbc
      2 influxdb
      2 imap
      2 grpc
      2 ftp
      2 exasol
      2 discord
      2 dingding
      2 datadog
      2 common
      2 cncf
      2 asana
      1 ssh
      1 singularity
      1 sendgrid
      1 samba
      1 papermill
      1 openfaas
      1 elasticsearch
      1 cloudant
      1 celery

class TestEmrHook(unittest.TestCase):
@pytest.mark.skipif(mock_emr is None, reason='moto package not present')
class TestEmrHook:
@mock_emr
def test_get_conn_returns_a_boto3_connection(self):
hook = EmrHook(aws_conn_id='aws_default', region_name='ap-southeast-2')
Expand Down Expand Up @@ -59,13 +60,63 @@ def test_create_job_flow_extra_args(self):
# AmiVersion is really old and almost no one will use it anymore, but
# it's one of the "optional" request params that moto supports - it's
# coverage of EMR isn't 100% it turns out.
cluster = hook.create_job_flow({'Name': 'test_cluster', 'ReleaseLabel': '', 'AmiVersion': '3.2'})

with pytest.warns(None): # Expected no warnings if ``emr_conn_id`` exists with correct conn_type
cluster = hook.create_job_flow({'Name': 'test_cluster', 'ReleaseLabel': '', 'AmiVersion': '3.2'})
cluster = client.describe_cluster(ClusterId=cluster['JobFlowId'])['Cluster']

# The AmiVersion comes back as {Requested,Running}AmiVersion fields.
assert cluster['RequestedAmiVersion'] == '3.2'

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
def test_empty_emr_conn_id(self, mocked_get_conn):
    """``emr_conn_id=None`` should skip the connection lookup and use only the overrides."""
    run_job_flow_mock = mock.MagicMock()
    mocked_get_conn.return_value.run_job_flow = run_job_flow_mock

    hook = EmrHook(aws_conn_id="aws_default", emr_conn_id=None)
    hook.create_job_flow({"foo": "bar"})
    run_job_flow_mock.assert_called_once_with(foo="bar")

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
def test_missing_emr_conn_id(self, mocked_get_conn):
    """A non-existent ``emr_conn_id`` should warn and fall back to an empty config."""
    run_job_flow_mock = mock.MagicMock()
    mocked_get_conn.return_value.run_job_flow = run_job_flow_mock

    hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="not-exists-emr-conn-id")
    expected_warning = r"Unable to find Amazon Elastic MapReduce Connection ID 'not-exists-emr-conn-id',.*"
    with pytest.warns(UserWarning, match=expected_warning):
        hook.create_job_flow({"foo": "bar"})
    run_job_flow_mock.assert_called_once_with(foo="bar")

@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
def test_emr_conn_id_wrong_conn_type(self, mocked_get_conn):
    """An existing ``emr_conn_id`` with an unexpected ``conn_type`` should warn but proceed."""
    run_job_flow_mock = mock.MagicMock()
    mocked_get_conn.return_value.run_job_flow = run_job_flow_mock

    with mock.patch.dict("os.environ", AIRFLOW_CONN_WRONG_TYPE_CONN="aws://"):
        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="wrong_type_conn")
        expected_warning = (
            r"Amazon Elastic MapReduce Connection expected connection type 'emr'"
            r".* This connection might not work correctly."
        )
        with pytest.warns(UserWarning, match=expected_warning):
            hook.create_job_flow({"foo": "bar"})
        run_job_flow_mock.assert_called_once_with(foo="bar")

@pytest.mark.parametrize("aws_conn_id", ["aws_default", None])
@pytest.mark.parametrize("emr_conn_id", ["emr_default", None])
def test_emr_connection(self, aws_conn_id, emr_conn_id):
    """``EmrHook.test_connection`` always reports a failed (untestable) state."""
    hook = EmrHook(aws_conn_id=aws_conn_id, emr_conn_id=emr_conn_id)
    status, message = hook.test_connection()
    assert not status
    assert message.startswith("'Amazon Elastic MapReduce' Airflow Connection cannot be tested")

@mock_emr
def test_get_cluster_id_by_name(self):
"""
Expand Down