Skip to content

Commit

Permalink
Add future-compatible mongo Hook typing (#31289)
Browse files Browse the repository at this point in the history
We are migrating to new mongo library soon and it has typing
support. This one adds propert typing that will prevent
MyPy from failing once mongo gets upgraded.
  • Loading branch information
potiuk authored May 15, 2023
1 parent c605ef0 commit 0117246
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 5 deletions.
5 changes: 4 additions & 1 deletion airflow/providers/amazon/aws/transfers/mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast

from bson import json_util
from pymongo.command_cursor import CommandCursor
from pymongo.cursor import Cursor

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
Expand Down Expand Up @@ -96,7 +98,7 @@ def execute(self, context: Context):

# Grab collection and execute query according to whether or not it is a pipeline
if self.is_pipeline:
results = MongoHook(self.mongo_conn_id).aggregate(
results: CommandCursor[Any] | Cursor = MongoHook(self.mongo_conn_id).aggregate(
mongo_collection=self.mongo_collection,
aggregate_query=cast(list, self.mongo_query),
mongo_db=self.mongo_db,
Expand All @@ -109,6 +111,7 @@ def execute(self, context: Context):
query=cast(dict, self.mongo_query),
projection=self.mongo_projection,
mongo_db=self.mongo_db,
find_one=False,
)

# Performs transform then stringifies the docs results into json format
Expand Down
30 changes: 28 additions & 2 deletions airflow/providers/mongo/hooks/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@

from ssl import CERT_NONE
from types import TracebackType
from typing import Any, overload
from urllib.parse import quote_plus, urlunsplit

import pymongo
from pymongo import MongoClient, ReplaceOne

from airflow.hooks.base import BaseHook
from airflow.typing_compat import Literal


class MongoHook(BaseHook):
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
self.mongo_conn_id = conn_id
self.connection = self.get_connection(conn_id)
self.extras = self.connection.extra_dejson.copy()
self.client = None
self.client: MongoClient | None = None
self.uri = self._create_uri()

def __enter__(self):
Expand Down Expand Up @@ -134,15 +136,39 @@ def aggregate(

return collection.aggregate(aggregate_query, **kwargs)

@overload
def find(
self,
mongo_collection: str,
query: dict,
find_one: bool = False,
find_one: Literal[False],
mongo_db: str | None = None,
projection: list | dict | None = None,
**kwargs,
) -> pymongo.cursor.Cursor:
...

@overload
def find(
self,
mongo_collection: str,
query: dict,
find_one: Literal[True],
mongo_db: str | None = None,
projection: list | dict | None = None,
**kwargs,
) -> Any | None:
...

def find(
self,
mongo_collection: str,
query: dict,
find_one: bool = False,
mongo_db: str | None = None,
projection: list | dict | None = None,
**kwargs,
) -> pymongo.cursor.Cursor | Any | None:
"""
Runs a mongo find query and returns the results
https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
Expand Down
12 changes: 10 additions & 2 deletions tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def test_execute(self, mock_s3_hook, mock_mongo_hook):
operator.execute(None)

mock_mongo_hook.return_value.find.assert_called_once_with(
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
mongo_collection=MONGO_COLLECTION,
query=MONGO_QUERY,
find_one=False,
mongo_db=None,
projection=None,
)

op_stringify = self.mock_operator._stringify
Expand All @@ -119,7 +123,11 @@ def test_execute_compress(self, mock_s3_hook, mock_mongo_hook):
operator.execute(None)

mock_mongo_hook.return_value.find.assert_called_once_with(
mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
mongo_collection=MONGO_COLLECTION,
query=MONGO_QUERY,
find_one=False,
mongo_db=None,
projection=None,
)

op_stringify = self.mock_operator._stringify
Expand Down

0 comments on commit 0117246

Please sign in to comment.