Skip to content

Commit

Permalink
Added ordering key option for PubSubPublishMessageOperator GCP Operat…
Browse files Browse the repository at this point in the history
…or (apache#39955)

* feature/gcp-pubsub-operator-ordering-key

* fix provider checks test

---------

Co-authored-by: Mehdi GATI <mehdi_gati@ext.carrefour.com>
  • Loading branch information
2 people authored and romsharon98 committed Jul 26, 2024
1 parent 4257dd6 commit 0c98c5f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 3 deletions.
11 changes: 10 additions & 1 deletion airflow/providers/google/cloud/hooks/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.exceptions import NotFound
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
from google.cloud.pubsub_v1.types import PublisherOptions
from google.pubsub_v1.services.subscriber.async_client import SubscriberAsyncClient
from googleapiclient.errors import HttpError

Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
enable_message_ordering: bool = False,
**kwargs,
) -> None:
if kwargs.get("delegate_to") is not None:
Expand All @@ -90,6 +92,7 @@ def __init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self.enable_message_ordering = enable_message_ordering
self._client = None

def get_conn(self) -> PublisherClient:
Expand All @@ -99,7 +102,13 @@ def get_conn(self) -> PublisherClient:
:return: Google Cloud Pub/Sub client object.
"""
if not self._client:
self._client = PublisherClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
self._client = PublisherClient(
credentials=self.get_credentials(),
client_info=CLIENT_INFO,
publisher_options=PublisherOptions(
enable_message_ordering=self.enable_message_ordering,
),
)
return self._client

@cached_property
Expand Down
18 changes: 18 additions & 0 deletions airflow/providers/google/cloud/operators/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
m1 = {"data": b"Hello, World!", "attributes": {"type": "greeting"}}
m2 = {"data": b"Knock, knock"}
m3 = {"attributes": {"foo": ""}}
m4 = {"data": b"Who's there?", "attributes": {"ordering_key": "knock_knock"}}
t1 = PubSubPublishMessageOperator(
project_id="my-project",
Expand All @@ -613,6 +614,15 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
dag=dag,
)
t2 = PubSubPublishMessageOperator(
project_id="my-project",
topic="my_topic",
messages=[m4],
create_topic=True,
enable_message_ordering=True,
dag=dag,
)
``project_id``, ``topic``, and ``messages`` are templated so you can use Jinja templating
in their values.
Expand All @@ -632,6 +642,10 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
https://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage
:param gcp_conn_id: The connection ID to use connecting to
Google Cloud.
:param enable_message_ordering: If true, messages published with the same
ordering_key in PubsubMessage will be delivered to the subscribers in the order
in which they are received by the Pub/Sub system. Otherwise, they may be
delivered in any order. Default is False.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
Expand All @@ -646,6 +660,7 @@ class PubSubPublishMessageOperator(GoogleCloudBaseOperator):
"project_id",
"topic",
"messages",
"enable_message_ordering",
"impersonation_chain",
)
ui_color = "#0273d4"
Expand All @@ -657,6 +672,7 @@ def __init__(
messages: list,
project_id: str = PROVIDE_PROJECT_ID,
gcp_conn_id: str = "google_cloud_default",
enable_message_ordering: bool = False,
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
Expand All @@ -665,12 +681,14 @@ def __init__(
self.topic = topic
self.messages = messages
self.gcp_conn_id = gcp_conn_id
self.enable_message_ordering = enable_message_ordering
self.impersonation_chain = impersonation_chain

def execute(self, context: Context) -> None:
hook = PubSubHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
enable_message_ordering=self.enable_message_ordering,
)

self.log.info("Publishing to topic %s", self.topic)
Expand Down
9 changes: 7 additions & 2 deletions tests/providers/google/cloud/hooks/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.exceptions import NotFound
from google.cloud.pubsub_v1.types import ReceivedMessage
from google.cloud.pubsub_v1.types import PublisherOptions, ReceivedMessage
from googleapiclient.errors import HttpError

from airflow.providers.google.cloud.hooks.pubsub import PubSubAsyncHook, PubSubException, PubSubHook
Expand Down Expand Up @@ -86,7 +86,12 @@ def setup_method(self):
def test_publisher_client_creation(self, mock_client, mock_get_creds):
assert self.pubsub_hook._client is None
result = self.pubsub_hook.get_conn()
mock_client.assert_called_once_with(credentials=mock_get_creds.return_value, client_info=CLIENT_INFO)

mock_client.assert_called_once_with(
credentials=mock_get_creds.return_value,
client_info=CLIENT_INFO,
publisher_options=PublisherOptions(enable_message_ordering=False),
)
assert mock_client.return_value == result
assert self.pubsub_hook._client == result

Expand Down
18 changes: 18 additions & 0 deletions tests/providers/google/cloud/operators/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
{"data": b"Knock, knock"},
{"attributes": {"foo": ""}},
]
TEST_MESSAGES_ORDERING_KEY = [
{"data": b"Hello, World!", "attributes": {"ordering_key": "key"}},
]


class TestPubSubTopicCreateOperator:
Expand Down Expand Up @@ -235,6 +238,21 @@ def test_publish(self, mock_hook):
project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES
)

@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_publish_with_ordering_key(self, mock_hook):
operator = PubSubPublishMessageOperator(
task_id=TASK_ID,
project_id=TEST_PROJECT,
topic=TEST_TOPIC,
messages=TEST_MESSAGES_ORDERING_KEY,
enable_message_ordering=True,
)

operator.execute(None)
mock_hook.return_value.publish.assert_called_once_with(
project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES_ORDERING_KEY
)


class TestPubSubPullOperator:
def _generate_messages(self, count):
Expand Down

0 comments on commit 0c98c5f

Please sign in to comment.