Skip to content

Commit

Permalink
Move the heavy lifting into base hook - Individual service hooks are …
Browse files Browse the repository at this point in the history
…no longer modified

Add unit testing
Update README
  • Loading branch information
ferruzzi committed Dec 12, 2022
1 parent 2d97536 commit 617665d
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 116 deletions.
42 changes: 42 additions & 0 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import warnings
from copy import deepcopy
from functools import wraps
from os import PathLike
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union

import boto3
Expand All @@ -43,6 +45,7 @@
from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
from botocore.waiter import Waiter, WaiterModel
from dateutil.tz import tzlocal
from slugify import slugify

Expand All @@ -51,6 +54,7 @@
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -764,6 +768,44 @@ def test_connection(self):
except Exception as e:
return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}")

@cached_property
def waiter_path(self) -> PathLike[str] | None:
    """
    Location of the custom waiter model file for this hook's service, if any.

    The file is looked up as ``waiters/<client_type>.json`` relative to the
    directory one level above the one containing this module.

    :return: The resolved path when the file exists, otherwise None.
    """
    relative_model = f"waiters/{self.client_type}.json"
    candidate = (Path(__file__).parents[1] / relative_model).resolve()
    if not candidate.exists():
        return None
    return candidate

def get_waiter(self, waiter_name: str) -> Waiter:
    """
    Return the waiter called ``waiter_name`` for this hook's service.

    A custom waiter defined in the service's waiter model file takes
    precedence. If there is no model file, or it does not define a waiter with
    that name, the request is passed through to the service client's own
    ``get_waiter``.

    :param waiter_name: The name of the waiter. The name should exactly match the
        name of the key in the waiter model file (typically this is CamelCase).
    :return: The custom waiter if one is defined, otherwise whatever the
        service client returns for that name.
    """
    model_path = self.waiter_path
    if model_path:
        # Parse the model file once and reuse the result for both the
        # membership check and the waiter construction.  JSON is read as
        # UTF-8 explicitly so decoding does not depend on the host locale.
        with open(model_path, encoding="utf-8") as config_file:
            config = json.load(config_file)
        if waiter_name in WaiterModel(config).waiter_names:
            return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
    # If there is no custom waiter found for the provided name,
    # then try checking the service's official waiters.
    return self.conn.get_waiter(waiter_name)

def list_waiters(self) -> list[str]:
    """
    List the names of every waiter available for the service.

    :return: The official waiter names followed by the custom waiter names.
    """
    waiter_names = list(self._list_official_waiters())
    waiter_names.extend(self._list_custom_waiters())
    return waiter_names

def _list_official_waiters(self) -> list[str]:
return self.conn.waiter_names

def _list_custom_waiters(self) -> list[str]:
if not self.waiter_path:
return []
with open(self.waiter_path) as config_file:
model_config = json.load(config_file)
return WaiterModel(model_config).waiter_names


class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
Expand Down
4 changes: 0 additions & 4 deletions airflow/providers/amazon/aws/hooks/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from botocore.signers import RequestSigner

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.waiters.eks import EksBotoWaiter
from airflow.utils import yaml
from airflow.utils.json import AirflowJsonEncoder

Expand Down Expand Up @@ -597,6 +596,3 @@ def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:

# remove any base64 encoding padding:
return "k8s-aws-v1." + base64_url.rstrip("=")

def get_waiter(self, waiter_name):
    """Build an ``EksBotoWaiter`` around ``self.conn`` and return its waiter named ``waiter_name``."""
    eks_waiters = EksBotoWaiter(client=self.conn)
    return eks_waiters.waiter(waiter_name)
77 changes: 26 additions & 51 deletions airflow/providers/amazon/aws/waiters/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,31 @@
under the License.
-->

This module is for custom Boto3 waiters. A BaseWaiter is provided in order to make
future additions as painless as possible. Since documentation on creating custom
waiters is pretty sparse out in the wild, this document can act as a short and rough
guide. It does not cover all edge cases and is meant as a rough quickstart guide.
This module is for custom Boto3 waiter configuration files. Since documentation
on creating custom waiters is pretty sparse out in the wild, this document can
act as a rough quickstart guide. It is not meant to cover all edge cases.

# To add a new custom waiter

## Create or modify the service waiter file
## Create or modify the service waiter config file

Find or create a file for the service it is related to, for example waiters/eks.py
Find or create a file for the service it is related to, for example waiters/eks.json

### In the service waiter file
### In the service waiter config file

#### 1: Create the waiter model config

Build or add to the waiter model config dictionary in that file. For examples of what these should
look like, have a look through some official waiter models. Some examples:
Build or add to the waiter model config json in that file. For examples of what these
should look like, have a look through some official waiter models. Some examples:

* [Cloudwatch](https://github.com/boto/botocore/blob/develop/botocore/data/cloudwatch/2010-08-01/waiters-2.json)
* [EC2](https://github.com/boto/botocore/blob/develop/botocore/data/ec2/2016-11-15/waiters-2.json)
* [EKS](https://github.com/boto/botocore/blob/develop/botocore/data/eks/2017-11-01/waiters-2.json)

Below is an example of a working waiter model config that will make an EKS waiter which will wait for
all Nodegroups in a cluster to be deleted. An explanation follows the code snippet.
all Nodegroups in a cluster to be deleted. An explanation follows the code snippet. Note the backticks
to escape the integers in the "argument" values.

```python
eks_waiter_model_config = {
```json
{
"version": 2,
"waiters": {
"all_nodegroups_deleted": {
Expand All @@ -53,20 +51,19 @@ eks_waiter_model_config = {
"acceptors": [
{
"matcher": "path",
# Note the backticks to escape the integer value.
"argument": "length(nodegroups[]) == `0`",
"expected": True,
"state": "success",
"expected": true,
"state": "success"
},
{
"matcher": "path",
"expected": True,
"expected": true,
"argument": "length(nodegroups[]) > `0`",
"state": "retry",
},
],
"state": "retry"
}
]
}
},
}
}
```

Expand All @@ -81,44 +78,22 @@ state does not go to `success` before the maxAttempts number of tries, the waite
WaiterException. Both `delay` and `maxAttempts` can be overridden by the user when calling
`waiter.wait()` like any other waiter.

Using the above waiter will look like this:
`EksHook().get_waiter("all_nodegroups_deleted").wait(clusterName="my_cluster")`


#### 2: Create the waiter class


Create the waiter class in the format of {Service}Waiter and inherit BaseBotoWaiter. The init method
for the new class should call `super().__init__` and provide the client (likely `hook.conn`) and the
model config dictionary (created above).
### That's It!

```python
class EksBotoWaiter(BaseBotoWaiter):
def __init__(self, client: BaseAwsConnection):
super().__init__(client=client, model_config=eks_waiter_model_config)
```

### In the hook file

#### Add a get_waiter method

For example, in hooks/eks.py, import the above EksBotoWaiter and add the following method to the EksHook class:
The AwsBaseHook handles the rest. Using the above waiter will look like this:
`EksHook().get_waiter("all_nodegroups_deleted").wait(clusterName="my_cluster")`
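and for testing purposes, a `_list_custom_waiters()` helper method is provided which can
be used the same way: `EksHook()._list_custom_waiters()`. The public `list_waiters()` method
returns the names of both the official and the custom waiters.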

```python
def get_waiter(self, waiter_name):
return EksBotoWaiter(client=self.conn).waiter(waiter_name)
```

### In your Operators (How to use these)

Once configured correctly, the custom waiter will be nearly indistinguishable from an official waiter.
Below is an example of an official waiter followed by a custom one.

```python
eks_hook.conn.get_waiter("nodegroup_deleted").wait(
clusterName=self.cluster_name, nodegroupName=self.nodegroup_name
)
eks_hook.get_waiter("all_nodegroups_deleted").wait(clusterName=self.cluster_name)
EksHook().conn.get_waiter("nodegroup_deleted").wait(clusterName=cluster_name, nodegroupName=nodegroup_name)
EksHook().get_waiter("all_nodegroups_deleted").wait(clusterName=cluster_name)
```

Note that since the get_waiter is in the hook instead of on the client side, a custom waiter is
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/amazon/aws/waiters/base_waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

from __future__ import annotations

import boto3
from botocore.waiter import Waiter, WaiterModel, create_waiter_with_client

from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection


class BaseBotoWaiter:
"""
Expand All @@ -29,7 +28,7 @@ class BaseBotoWaiter:
For more details, see airflow/providers/amazon/aws/waiters/README.md
"""

def __init__(self, client: BaseAwsConnection, model_config: dict) -> None:
def __init__(self, client: boto3.client, model_config: dict) -> None:
self.model = WaiterModel(model_config)
self.client = client

Expand Down
24 changes: 24 additions & 0 deletions airflow/providers/amazon/aws/waiters/eks.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{
"version": 2,
"waiters": {
"all_nodegroups_deleted": {
"operation": "ListNodegroups",
"delay": 30,
"maxAttempts": 60,
"acceptors": [
{
"matcher": "path",
"argument": "length(nodegroups[]) == `0`",
"expected": true,
"state": "success"
},
{
"matcher": "path",
"expected": true,
"argument": "length(nodegroups[]) > `0`",
"state": "retry"
}
]
}
}
}
58 changes: 0 additions & 58 deletions airflow/providers/amazon/aws/waiters/eks.py

This file was deleted.

1 change: 1 addition & 0 deletions dev/provider_packages/MANIFEST_TEMPLATE.in.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

{% if PROVIDER_PACKAGE_ID == 'amazon' %}
include airflow/providers/amazon/aws/hooks/batch_waiters.json
include airflow/providers/amazon/aws/waiters/*.json
{% elif PROVIDER_PACKAGE_ID == 'google' %}
include airflow/providers/google/cloud/example_dags/*.yaml
include airflow/providers/google/cloud/example_dags/*.sql
Expand Down
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/waiters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit 617665d

Please sign in to comment.