From ef12228701d66d4a52eb867891c914cd4f09a002 Mon Sep 17 00:00:00 2001 From: uzhastik Date: Fri, 9 Feb 2024 13:10:58 +0300 Subject: [PATCH 01/34] initial commit --- airflow/example_dags/airflow_dag_yq.py | 220 +++++++++++++++++ airflow/providers/yandex/hooks/yandex.py | 65 ++++- .../providers/yandex/hooks/yandexcloud_yq.py | 228 ++++++++++++++++++ .../yandex/operators/yandexcloud_yq.py | 134 ++++++++++ airflow/providers/yandex/triggers/__init__.py | 16 ++ .../yandex/triggers/yandexcloud_yq.py | 93 +++++++ 6 files changed, 753 insertions(+), 3 deletions(-) create mode 100644 airflow/example_dags/airflow_dag_yq.py create mode 100644 airflow/providers/yandex/hooks/yandexcloud_yq.py create mode 100644 airflow/providers/yandex/operators/yandexcloud_yq.py create mode 100644 airflow/providers/yandex/triggers/__init__.py create mode 100644 airflow/providers/yandex/triggers/yandexcloud_yq.py diff --git a/airflow/example_dags/airflow_dag_yq.py b/airflow/example_dags/airflow_dag_yq.py new file mode 100644 index 0000000000000..8714336bc0c90 --- /dev/null +++ b/airflow/example_dags/airflow_dag_yq.py @@ -0,0 +1,220 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAG demonstrating the usage of the BashOperator.""" +from __future__ import annotations + +import datetime + +import pendulum +import dateutil + +from airflow.models.dag import DAG +from airflow.operators.bash import BashOperator +from airflow.operators.empty import EmptyOperator +from airflow.providers.yandex.operators.yandexcloud_yq import YQExecuteQueryOperator +from airflow.operators.python import PythonOperator +from airflow.decorators import task +import base64 +# import airflow.providers.yandex.operators.yandexcloud_dataproc + +with DAG( + dag_id="yq_operator", + # schedule="@daily", + schedule_interval='30 2 * * *', + start_date=pendulum.datetime(2023, 1, 16, 19, 15, tz="UTC"), + catchup=False, + dagrun_timeout=datetime.timedelta(minutes=60) +) as dag: + run_this_last = EmptyOperator( + task_id="run_this_last", + ) + + # # [START howto_operator_bash] + # run_this = BashOperator( + # task_id="run_after_loop", + # bash_command="echo 1", + # ) + # # [END howto_operator_bash] + + # run_this >> run_this_last + + # # [END howto_operator_bash_template] + # also_run_this >> run_this_last + + # @task + # def yq_read_data(): + # operator = YQExecuteQueryOperator(task_id="samplequery2", sql="select 22 as d, 33 as t") + # return operator.execute() + + # @task + # def ydb_write_data(yq_data): + # ydb_bulk_insert = YDBBulkUpsertOperator(task_id="bulk_insert", + # endpoint="grpcs://ydb.serverless.yandexcloud.net:2135", + # database="/ru-central1/b1g8skpblkos03malf3s/etndta0jk4us20e557i7", + # table="my_table", + # column_types={"id": ydb.PrimitiveType.Uint64, + # "name": ydb.PrimitiveType.Utf8}, + # values=[{"id":1,"name":"v"}]) + # return ydb_bulk_insert.execute() + + # data = yq_read_data() + # ydb_write_data(data) + + def base64ToString(b): + return base64.b64decode(b).decode('utf-8') + + def get_column_index(columns, name): + for index, column in enumerate(columns): + if column["name"] == name: + return index + + def get_col_by_name(row, columns, name): + index = get_column_index(columns, name) + value = row[index] + # print(value) + return value + + def process_query_count_result(**kwargs): + ti = kwargs['ti'] + result = ti.xcom_pull(task_ids='get_queries_count') + # print(result) + + print(f"Incoming rows={result['rows']}") + + def process_result(**kwargs): + ti = kwargs['ti'] + result = ti.xcom_pull(task_ids='samplequery2') + + print(f"Incoming rows={result['rows']}") + + + # ydb_bulk_insert >> run_this_last + + query = """ +$parse_ingress_bytes = ($m) -> { + $t = "IngressBytes: ["; + $start = Find($m, $t); + $end = Find($m, "]", $start); + return Cast(Substring($m, $start+LEN($t), $end-$start-LEN($t)) as Uint64); + }; + +$parse_folder_id = ($m) -> { + $t = "scope: [yandexcloud://"; + $start = Find($m, $t); + $end = Find($m, "]", $start); + return Substring($m, $start+LEN($t), $end-$start-LEN($t)); + }; + +$parse_status = ($m) -> { + $t = ", status:"; + $start = Find($m, $t); + $end = LEN($m); + return String::Strip(Substring($m, $start+LEN($t), $end-$start-LEN($t))); + }; + +select * from ( +select `@timestamp` as ts, $parse_ingress_bytes(message) as ingress_bytes, $parse_folder_id(message) as folder_id, $parse_status(message) as status from (select + `yq_prod_logs_cold_projected`.* +FROM + `yq_prod_logs_cold_projected`) +where component="YQ_AUDIT" and message like "FinalStatus%" +and message like "%IngressBytes%" +and `date` between Date("{{ data_interval_start | ds }}") and Date("{{ data_interval_end | ds }}") +) +where COALESCE(folder_id,"") != "" +limit 1000; + +""" + + yq_operator2 = YQExecuteQueryOperator(task_id="samplequery2", sql=query, connection_id="yandexcloud_default") + yq_operator2 >> run_this_last + + query_count_queries = """ +$parse_folder_id = ($m) -> { + $t = "scope: [yandexcloud://"; + $start = Find($m, $t); + $end = Find($m, "]", $start); + return Substring($m, $start+LEN($t), $end-$start-LEN($t)); + }; + +$parse_status = ($m) -> { + $t = ", status:"; + $start = Find($m, $t); + $end = LEN($m); + return String::Strip(Substring($m, $start+LEN($t), $end-$start-LEN($t))); + }; + +$parse_query_id = ($m) -> { + $t = "query id: ["; + $start = Find($m, $t); + $end = Find($m, "]", $start); + return Substring($m, $start+LEN($t), $end-$start-LEN($t)); + }; + +select * from ( +select `@timestamp` as ts, $parse_folder_id(message) as folder_id, $parse_status(message) as status, $parse_query_id(message) as query_id from (select + `yq_prod_logs_cold_projected`.* +FROM + `yq_prod_logs_cold_projected`) +where component="YQ_AUDIT" and message like "FinalStatus%" +and `date` between Date("{{ data_interval_start | ds }}") and Date("{{ data_interval_end | ds }}") +) +where COALESCE(folder_id,"") != "" +limit 1000; + +""" + + process_query_count_task = PythonOperator( task_id='process_query_count_result', + python_callable=process_query_count_result, + provide_context=True) + + yq_operator_queries_count = YQExecuteQueryOperator(task_id="get_queries_count", sql=query_count_queries, connection_id="yandexcloud_default") + yq_operator_queries_count >> process_query_count_task + + # yq_operator3 = YQExecuteQueryOperator(task_id="samplequery3", sql="select 33 as d, 44 as t") + # yq_operator3 >> run_this_last + + # yq_operator4 = YQExecuteQueryOperator(task_id="samplequery4", sql="select 33 as d, 44 as t") + # # yq_operator4 >> ydb_bulk_insert + + + # yq_operator = YQExecuteQueryOperator(task_id="samplequery", sql="select 1") + # yq_operator >> yq_operator2 + + + process_result_task = PythonOperator( + task_id='process_result', + python_callable=process_result, + provide_context=True) + + yq_operator2 >> process_result_task + + + + + # # [START howto_operator_bash_skip] + # this_will_skip = BashOperator( + # task_id="this_will_skip", + # bash_command='echo "hello world"; exit 99;', + # dag=dag, + # ) + # # [END howto_operator_bash_skip] + # this_will_skip >> run_this_last + +if __name__ == "__main__": + dag.test() diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index aa9cf4302ebfd..5290816c464a1 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -16,12 +16,17 @@ # under the License. from __future__ import annotations +import json import warnings +import requests +import time +import jwt +from requests.packages.urllib3.util.retry import Retry from typing import Any import yandexcloud -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook from airflow.providers.yandex.utils.credentials import ( get_credentials, @@ -132,13 +137,13 @@ def __init__( self.connection_id = yandex_conn_id or connection_id or default_conn_name self.connection = self.get_connection(self.connection_id) self.extras = self.connection.extra_dejson - credentials = get_credentials( + self.credentials = get_credentials( oauth_token=self._get_field("oauth"), service_account_json=self._get_field("service_account_json"), service_account_json_path=self._get_field("service_account_json_path"), ) sdk_config = self._get_endpoint() - self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(), **sdk_config, **credentials) + self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(), **sdk_config, **self.credentials) self.default_folder_id = default_folder_id or self._get_field("folder_id") self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key") self.default_service_account_id = default_service_account_id or get_service_account_id( @@ -158,3 +163,57 @@ def _get_field(self, field_name: str, default: Any = None) -> Any: if not hasattr(self, "extras"): return default return get_field_from_extras(self.extras, field_name, default) + + def get_iam_token(self) -> str: + if "oauth" in self.credentials: + return YandexCloudBaseHook._resolve_oauth(self.credentials["oauth"]) + if "service_account_key" in self.credentials: + return YandexCloudBaseHook._resolve_service_account_key(self.credentials["service_account_key"]) + raise AirflowException(f"Unknown credentials type {self.credentials.keys()}") + + @staticmethod + def _resolve_oauth(self, token: str) -> str: + pass + + @staticmethod + def _resolve_service_account_key(sa_info) -> str: + session = YandexCloudBaseHook.create_session() + + api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' + now = int(time.time()) + payload = { + 'aud': api, + 'iss': sa_info["service_account_id"], + 'iat': now, + 'exp': now + 360} + + encoded_token = jwt.encode( + payload, + sa_info["private_key"], + algorithm='PS256', + headers={'kid': sa_info["id"]}) + + data = {"jwt": encoded_token} + iam_response = session.post(api, json=data) + iam_response.raise_for_status() + + return iam_response.json()["iamToken"] + + @staticmethod + def create_session() -> requests.Session: + session = requests.Session() + session.verify = False + retry = Retry( + backoff_factor=0.3, + total=10 + ) + session.mount( + 'http://', + requests.adapters.HTTPAdapter(max_retries=retry) + ) + session.mount( + 'https://', + requests.adapters.HTTPAdapter(max_retries=retry) + ) + + return session diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py new file mode 100644 index 0000000000000..3d423293171f2 --- /dev/null +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging + +# These two lines enable debugging at httplib level (requests->urllib3->http.client) +# You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. +# The only thing missing will be the response.body which is not logged. +try: + import http.client as http_client +except ImportError: + # Python 2 + import httplib as http_client +http_client.HTTPConnection.debuglevel = 1 + +# You must initialize logging, otherwise you'll not see debug output. +logging.basicConfig() +logging.getLogger().setLevel(logging.DEBUG) +requests_log = logging.getLogger("requests.packages.urllib3") +requests_log.setLevel(logging.DEBUG) +requests_log.propagate = True + +from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook +from airflow.exceptions import AirflowException +import requests +from requests.packages.urllib3.util.retry import Retry +from enum import Enum +import time +from datetime import timedelta, datetime +import aiohttp + +class QueryType(Enum): + ANALYTICS = 1 + STREAMING = 2 + +class YQHook(YandexCloudBaseHook): + """ + A base hook for Yandex.Cloud Data Proc. + + :param yandex_conn_id: The connection ID to use when fetching connection info. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.query_id: str | None = None + + + def start_execute_query(self, query_type: QueryType, query_text: str|None, name: str|None=None, description: str | None = None) -> str: + # self.default_folder_id + type = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" + + with YQHook.create_session(self.get_iam_token()) as session: + + data = {"name": name, + "type": type, + "text": query_text, + "description": description} + self.log.info(f"folder={self.default_folder_id}") + response = session.post(f"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries?project={self.default_folder_id}", + json=data) + response.raise_for_status() + + self.query_id = response.json()["id"] + + return self.query_id + + def wait_for_query_to_complete(self, execution_timeout: timedelta): + status = None + try: + status = self.wait_results(self.query_id, execution_timeout) + except TimeoutError as err: + self.stop_query(self.query_id) + raise + + def get_query_result(self, query_id): + self.log.info(f"get_query_result query_id={query_id}") + query_info = self.get_queryinfo(query_id) + if query_info["status"] == "FAILED": + issues = query_info["issues"] + raise RuntimeError("Query failed", issues=issues) + + result_set_count = len(query_info["result_sets"]) + self.log.debug(f"result set count {result_set_count}") + + query_results = self.query_results(query_id, result_set_count) + self.log.debug(query_results) + return query_results + + def get_pandas_df(self)-> pd.DataFrame: + return pd.DataFrame() + + def query_results(self, query_id:str, result_set_count:int)->object: + results = list() + limit = 1000 + offset = 0 + + iam_token = self.get_iam_token() + with YQHook.create_session(iam_token) as session: + for result_index in range(0, result_set_count): + columns = None + rows = [] + while True: + print(f"limit={limit} offset={offset}") + response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/results/{result_index}?project={self.default_folder_id}&limit={limit}&offset={offset}') + response.raise_for_status() + + qresults = response.json() + print(qresults) + if columns is None: + columns = qresults["columns"] + + rows.extend( qresults["rows"]) + if len(qresults["rows"]) != limit: + break + else: + offset += limit + + results.append({"rows":rows, "columns": columns}) + + if len(results) == 1: + return results[0] + else: + return results + + def stop_current_query(self)->None: + self.stop_query(self.query_id) + + def get_queryinfo(self, query_id:str)->object: + iam_token = self.get_iam_token() + with YQHook.create_session(iam_token) as session: + response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}?project={self.default_folder_id}') + response.raise_for_status() + print(response.json()) + + return response.json() + + def stop_query(self, query_id:str)->None: + iam_token = self.get_iam_token() + with YQHook.create_session(iam_token) as session: + session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/stop?project={self.default_folder_id}') + + def get_query_status(self, query_id:str)->str: + iam_token = self.get_iam_token() + with YQHook.create_session(iam_token) as session: + response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/status?project={self.default_folder_id}') + response.raise_for_status() + status = response.json()["status"] + return status + + async def get_query_status_async(self, query_id:str)->str: + iam_token = self.get_iam_token() + + headers = YQHook.get_request_url_header_params(iam_token) + url = f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/status?project={self.default_folder_id}' + + self.log.info(f"Retrieving status for query id {query_id}, url {url}") + + async with aiohttp.ClientSession(headers=headers) as session: + async with session.get(url) as response: + status_code = response.status + assert status_code == 200 + resp = await response.json() + return resp["status"] + + @staticmethod + def get_request_url_header_params(iam_token: str|None=None)->dict[str,str]: + print(f"get_request_url_header_params iam_token={iam_token}") + headers = {} + if iam_token is not None: + headers['Authorization'] = f"{iam_token}" + + return headers + + def wait_results(self, query_id:str, execution_timeout: timedelta)->str: + execution_timeout = execution_timeout if execution_timeout is not None else timedelta(minutes=30) + + start = datetime.now() + while True: + if datetime.now() > start + execution_timeout: + raise TimeoutError("Query execution timeout") + + status = self.get_query_status(query_id) + if status not in ["RUNNING", "PENDING"]: + return status + + time.sleep(2) + + @staticmethod + def create_session(iamtoken: str | None = None) -> requests.Session: + session = requests.Session() + session.verify = False + session.timeout = 20 + retry = Retry( + backoff_factor=0.3, + total=10 + ) + session.mount( + 'http://', + requests.adapters.HTTPAdapter(max_retries=retry) + ) + session.mount( + 'https://', + requests.adapters.HTTPAdapter(max_retries=retry) + ) + + headers = YQHook.get_request_url_header_params(iamtoken) + for k, v in headers.items(): + session.headers[k] = v + + return session + + diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py new file mode 100644 index 0000000000000..3a0de480791ea --- /dev/null +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Sequence +from datetime import timedelta +from airflow.configuration import conf +import pandas as pd + + +from airflow.exceptions import AirflowException +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook, QueryType +from airflow.providers.yandex.triggers.yandexcloud_yq import YQQueryStatusTrigger + +if TYPE_CHECKING: + from airflow.utils.context import Context + +class YQExecuteQueryOperator(SQLExecuteQueryOperator): + """ + Executes sql code using a specific Trino query Engine. + + :param sql: the SQL code to be executed as a single string, or + a list of str (sql statements), or a reference to a template file. + """ + + template_fields: Sequence[str] = ("sql",) + template_fields_renderers = {"sql": "sql"} + template_ext: Sequence[str] = (".sql",) + ui_color = "#ededed" + + def __init__( + self, + *, + type: QueryType = QueryType.ANALYTICS, + name: str | None = None, + description: str | None = None, + folder_id: str | None = None, + connection_id: str | None = None, + public_ssh_key: str | None = None, + service_account_id: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.name = name + self.description = description + self.type = type + self.deferrable = deferrable + self.folder_id = folder_id + self.connection_id = connection_id + self.public_ssh_key = public_ssh_key + self.service_account_id = service_account_id + + self.hook: YQHook | None = None + + def execute(self, context: Context) -> None: + self.hook = YQHook( + yandex_conn_id=self.connection_id, + default_folder_id=self.folder_id, + default_public_ssh_key=self.public_ssh_key, + default_service_account_id=self.service_account_id + ) + + self.hook.start_execute_query(self.type, self.sql, self.name, self.description) + + # value = [1] + # # Deprecated + # context["task_instance"].xcom_push(key="query_result", value=value) + # return value + + self.log.info(f"deffered is allowed [{self.deferrable}]") + + if self.deferrable: + # if True: + self.defer( + trigger=YQQueryStatusTrigger( + poll_interval=timedelta(seconds=2).seconds, + query_id=self.hook.query_id, + connection_id=self.connection_id, + folder_id=self.folder_id, + public_ssh_key=self.public_ssh_key, + service_account_id=self.service_account_id + ), + method_name="execute_complete", + ) + else: + self.hook.wait_for_query_to_complete(self.execution_timeout) + return self.hook.get_query_result(self.hook.query_id) + + def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None: + if "status" in event and event["status"]!="COMPLETED": + msg = None + if 'message' in event: + msg = f"{event['status']}: {event['message']}" + else: + msg = event["status"] + raise AirflowException(msg) + else: + query_id = event["query_id"] + + hook = YQHook( connection_id=event["connection_id"], + default_folder_id=event["folder_id"], + default_public_ssh_key=event["public_ssh_key"], + default_service_account_id=event["service_account_id"]) + + result = hook.get_query_result(query_id) + self.log.info("%s completed successfully.", self.task_id) + return result + + def on_kill(): + if self.hook is not None: + self.hook.stop_current_query() + + @staticmethod + def to_dataframe(data): + column_names = [column["name"] for column in data["columns"]] + return pd.DataFrame(data["rows"], columns= column_names) diff --git a/airflow/providers/yandex/triggers/__init__.py b/airflow/providers/yandex/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/yandex/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/yandex/triggers/yandexcloud_yq.py b/airflow/providers/yandex/triggers/yandexcloud_yq.py new file mode 100644 index 0000000000000..1ff41d2a09c21 --- /dev/null +++ b/airflow/providers/yandex/triggers/yandexcloud_yq.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, AsyncIterator + +from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook, QueryType +from airflow.triggers.base import BaseTrigger, TriggerEvent +import traceback + +if TYPE_CHECKING: + from datetime import timedelta + + +class YQQueryStatusTrigger(BaseTrigger): + + def __init__( + self, + poll_interval: float, + query_id: str, + folder_id: str | None = None, + connection_id: str | None = None, + public_ssh_key: str | None = None, + service_account_id: str | None = None, + + ): + super().__init__() + self.poll_interval = poll_interval + self.query_id = query_id + self.connection_id = connection_id + self.folder_id = folder_id + self.public_ssh_key = public_ssh_key + self.service_account_id = service_account_id + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.yandex.triggers.yandexcloud_yq.YQQueryStatusTrigger", + { + "poll_interval": self.poll_interval, + "query_id": self.query_id, + "connection_id": self.connection_id, + "folder_id": self.folder_id, + "public_ssh_key": self.public_ssh_key, + "service_account_id": self.service_account_id + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + try: + while True: + status = await self.get_query_status(self.query_id) + if status not in ["RUNNING", "PENDING"]: + break + await asyncio.sleep(self.poll_interval) + + yield TriggerEvent( + { + "status": status, + "query_id": self.query_id, + "folder_id": self.folder_id, + "public_ssh_key": self.public_ssh_key, + "service_account_id": self.service_account_id, + "connection_id": self.connection_id + } + ) + except Exception as e: + message = f"{str(e)} trace={traceback.format_exc()}" + yield TriggerEvent({"status": "error", "message": message}) + + async def get_query_status(self, query_id: str) -> dict[str, Any]: + """Return True if the SQL query is still running otherwise return False.""" + hook = YQHook( + connection_id=self.connection_id + ) + return await hook.get_query_status_async(query_id) + + def _set_context(self, context): + pass From ae4fab2190e73b25ebb71f02ae08a7386893b300 Mon Sep 17 00:00:00 2001 From: uzhastik Date: Fri, 9 Feb 2024 18:03:00 +0300 Subject: [PATCH 02/34] support web link --- airflow/example_dags/airflow_dag_yq.py | 5 +- .../providers/yandex/hooks/yandexcloud_yq.py | 24 ++--- .../yandex/operators/yandexcloud_yq.py | 73 +++++---------- airflow/providers/yandex/provider.yaml | 21 ++++- airflow/providers/yandex/triggers/__init__.py | 16 ---- .../yandex/triggers/yandexcloud_yq.py | 93 ------------------- 6 files changed, 51 insertions(+), 181 deletions(-) delete mode 100644 airflow/providers/yandex/triggers/__init__.py delete mode 100644 airflow/providers/yandex/triggers/yandexcloud_yq.py diff --git a/airflow/example_dags/airflow_dag_yq.py b/airflow/example_dags/airflow_dag_yq.py index 8714336bc0c90..34f2e577dc1a3 100644 --- a/airflow/example_dags/airflow_dag_yq.py +++ b/airflow/example_dags/airflow_dag_yq.py @@ -40,6 +40,7 @@ catchup=False, dagrun_timeout=datetime.timedelta(minutes=60) ) as dag: + folder_id = "b1gaud5b392mmmeolb0k" run_this_last = EmptyOperator( task_id="run_this_last", ) @@ -141,7 +142,7 @@ def process_result(**kwargs): """ - yq_operator2 = YQExecuteQueryOperator(task_id="samplequery2", sql=query, connection_id="yandexcloud_default") + yq_operator2 = YQExecuteQueryOperator(task_id="samplequery2", sql=query, connection_id="yandexcloud_default", folder_id=folder_id) yq_operator2 >> run_this_last query_count_queries = """ @@ -183,7 +184,7 @@ def process_result(**kwargs): python_callable=process_query_count_result, provide_context=True) - yq_operator_queries_count = YQExecuteQueryOperator(task_id="get_queries_count", sql=query_count_queries, connection_id="yandexcloud_default") + yq_operator_queries_count = YQExecuteQueryOperator(task_id="get_queries_count", sql=query_count_queries, connection_id="yandexcloud_default", folder_id=folder_id) yq_operator_queries_count >> process_query_count_task # yq_operator3 = YQExecuteQueryOperator(task_id="samplequery3", sql="select 33 as d, 44 as t") diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 3d423293171f2..850ad2fe3a76d 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -37,6 +37,7 @@ from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.exceptions import AirflowException + import requests from requests.packages.urllib3.util.retry import Retry from enum import Enum @@ -62,18 +63,18 @@ def __init__(self, *args, **kwargs) -> None: def start_execute_query(self, query_type: QueryType, query_text: str|None, name: str|None=None, description: str | None = None) -> str: - # self.default_folder_id type = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" with YQHook.create_session(self.get_iam_token()) as session: + data = { + "name": name, + "type": type, + "text": query_text, + "description": description + } - data = {"name": name, - "type": type, - "text": query_text, - "description": description} self.log.info(f"folder={self.default_folder_id}") - response = session.post(f"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries?project={self.default_folder_id}", - json=data) + response = session.post(f"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries?project={self.default_folder_id}", json=data) response.raise_for_status() self.query_id = response.json()["id"] @@ -81,9 +82,8 @@ def start_execute_query(self, query_type: QueryType, query_text: str|None, name return self.query_id def wait_for_query_to_complete(self, execution_timeout: timedelta): - status = None try: - status = self.wait_results(self.query_id, execution_timeout) + return self.wait_results(self.query_id, execution_timeout) except TimeoutError as err: self.stop_query(self.query_id) raise @@ -102,9 +102,6 @@ def get_query_result(self, query_id): self.log.debug(query_results) return query_results - def get_pandas_df(self)-> pd.DataFrame: - return pd.DataFrame() - def query_results(self, query_id:str, result_set_count:int)->object: results = list() limit = 1000 @@ -180,10 +177,9 @@ async def get_query_status_async(self, query_id:str)->str: @staticmethod def get_request_url_header_params(iam_token: str|None=None)->dict[str,str]: - print(f"get_request_url_header_params iam_token={iam_token}") headers = {} if iam_token is not None: - headers['Authorization'] = f"{iam_token}" + headers['Authorization'] = f"Bearer {iam_token}" return headers diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index 3a0de480791ea..cd9d41c9fb7c5 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -16,30 +16,36 @@ # under the License. from __future__ import annotations -import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Sequence from datetime import timedelta from airflow.configuration import conf -import pandas as pd -from airflow.exceptions import AirflowException from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.models import BaseOperator, BaseOperatorLink, XCom +from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook, QueryType -from airflow.providers.yandex.triggers.yandexcloud_yq import YQQueryStatusTrigger if TYPE_CHECKING: from airflow.utils.context import Context + +class YQLink(BaseOperatorLink): + name = "Yandex Query" + + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): + return XCom.get_value(key="web_link", ti_key=ti_key) or "https://yq.cloud.yandex.ru" + + class YQExecuteQueryOperator(SQLExecuteQueryOperator): """ - Executes sql code using a specific Trino query Engine. + Executes sql code using Yandex Query service. - :param sql: the SQL code to be executed as a single string, or - a list of str (sql statements), or a reference to a template file. + :param sql: the SQL code to be executed as a single string """ + operator_extra_links = (YQLink(),) template_fields: Sequence[str] = ("sql",) template_fields_renderers = {"sql": "sql"} template_ext: Sequence[str] = (".sql",) @@ -80,51 +86,14 @@ def execute(self, context: Context) -> None: self.hook.start_execute_query(self.type, self.sql, self.name, self.description) - # value = [1] - # # Deprecated - # context["task_instance"].xcom_push(key="query_result", value=value) - # return value - - self.log.info(f"deffered is allowed [{self.deferrable}]") - - if self.deferrable: - # if True: - self.defer( - trigger=YQQueryStatusTrigger( - poll_interval=timedelta(seconds=2).seconds, - query_id=self.hook.query_id, - connection_id=self.connection_id, - folder_id=self.folder_id, - public_ssh_key=self.public_ssh_key, - service_account_id=self.service_account_id - ), - method_name="execute_complete", - ) - else: - self.hook.wait_for_query_to_complete(self.execution_timeout) - return self.hook.get_query_result(self.hook.query_id) - - def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None: - if "status" in event and event["status"]!="COMPLETED": - msg = None - if 'message' in event: - msg = f"{event['status']}: {event['message']}" - else: - msg = event["status"] - raise AirflowException(msg) - else: - query_id = event["query_id"] - - hook = YQHook( connection_id=event["connection_id"], - default_folder_id=event["folder_id"], - default_public_ssh_key=event["public_ssh_key"], - default_service_account_id=event["service_account_id"]) - - result = hook.get_query_result(query_id) - self.log.info("%s completed successfully.", self.task_id) - return result - - def on_kill(): + # pass to YQLink + web_link = f"https://yq.cloud.yandex.ru/folders/{self.folder_id}/ide/queries/{self.hook.query_id}" + context["ti"].xcom_push(key="web_link", value=web_link) + + self.hook.wait_for_query_to_complete(self.execution_timeout) + return self.hook.get_query_result(self.hook.query_id) + + def on_kill(self): if self.hook is not None: self.hook.stop_current_query() diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index 08c31f88d23c1..30d556f0c81d7 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -19,9 +19,8 @@ package-name: apache-airflow-providers-yandex name: Yandex description: | - This package is for Yandex, including: + Yandex including `Yandex.Cloud `__ - - `Yandex.Cloud `__ state: ready source-date-epoch: 1707636562 # note that those versions are maintained by release manager - do not update them manually @@ -63,11 +62,22 @@ integrations: logo: /integration-logos/yandex/Yandex-Cloud.png tags: [service] + - integration-name: Yandex.Cloud YQ + external-doc-url: https://cloud.yandex.com/dataproc + how-to-guide: + - /docs/apache-airflow-providers-yandex/operators.rst + logo: /integration-logos/yandex/Yandex-Cloud.png + tags: [service] + operators: - integration-name: Yandex.Cloud Dataproc python-modules: - airflow.providers.yandex.operators.yandexcloud_dataproc + - integration-name: Yandex.Cloud YQ + python-modules: + - airflow.providers.yandex.operators.yandexcloud_yq + hooks: - integration-name: Yandex.Cloud python-modules: @@ -75,13 +85,16 @@ hooks: - integration-name: Yandex.Cloud Dataproc python-modules: - airflow.providers.yandex.hooks.yandexcloud_dataproc + - integration-name: Yandex.Cloud YQ + python-modules: + - airflow.providers.yandex.hooks.yandexcloud_yq connection-types: - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook connection-type: yandexcloud -secrets-backends: - - airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend +extra-links: + - airflow.providers.yandex.operators.yandexcloud_yq.YQLink config: yandex: diff --git a/airflow/providers/yandex/triggers/__init__.py b/airflow/providers/yandex/triggers/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/airflow/providers/yandex/triggers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/yandex/triggers/yandexcloud_yq.py b/airflow/providers/yandex/triggers/yandexcloud_yq.py deleted file mode 100644 index 1ff41d2a09c21..0000000000000 --- a/airflow/providers/yandex/triggers/yandexcloud_yq.py +++ /dev/null @@ -1,93 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, Any, AsyncIterator - -from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook, QueryType -from airflow.triggers.base import BaseTrigger, TriggerEvent -import traceback - -if TYPE_CHECKING: - from datetime import timedelta - - -class YQQueryStatusTrigger(BaseTrigger): - - def __init__( - self, - poll_interval: float, - query_id: str, - folder_id: str | None = None, - connection_id: str | None = None, - public_ssh_key: str | None = None, - service_account_id: str | None = None, - - ): - super().__init__() - self.poll_interval = poll_interval - self.query_id = query_id - self.connection_id = connection_id - self.folder_id = folder_id - self.public_ssh_key = public_ssh_key - self.service_account_id = service_account_id - - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - "airflow.providers.yandex.triggers.yandexcloud_yq.YQQueryStatusTrigger", - { - "poll_interval": self.poll_interval, - "query_id": self.query_id, - "connection_id": self.connection_id, - "folder_id": self.folder_id, - "public_ssh_key": self.public_ssh_key, - "service_account_id": self.service_account_id - }, - ) - - async def run(self) -> AsyncIterator[TriggerEvent]: - try: - while True: - status = await self.get_query_status(self.query_id) - if status not in ["RUNNING", "PENDING"]: - break - await asyncio.sleep(self.poll_interval) - - yield TriggerEvent( - { - "status": status, - "query_id": self.query_id, - "folder_id": self.folder_id, - "public_ssh_key": self.public_ssh_key, - "service_account_id": self.service_account_id, - "connection_id": self.connection_id - } - ) - except Exception as e: - message = f"{str(e)} trace={traceback.format_exc()}" - yield TriggerEvent({"status": "error", "message": message}) - - async def get_query_status(self, query_id: str) -> dict[str, Any]: - """Return True if the SQL query is still running otherwise return False.""" - hook = YQHook( - connection_id=self.connection_id - ) - return await hook.get_query_status_async(query_id) - - def _set_context(self, context): - pass From aad01bd5718ee9b8be41a388e4f938c6f2c8decb Mon Sep 17 00:00:00 2001 From: uzhastik Date: Fri, 9 Feb 2024 18:35:19 +0300 Subject: [PATCH 03/34] move jwt logic out of base hook --- airflow/providers/yandex/hooks/yandex.py | 63 +-------------- .../providers/yandex/hooks/yandexcloud_yq.py | 76 +++++++++++-------- 2 files changed, 46 insertions(+), 93 deletions(-) diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index 5290816c464a1..8e508ccc6ca8c 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -16,17 +16,12 @@ # under the License. from __future__ import annotations -import json import warnings -import requests -import time -import jwt -from requests.packages.urllib3.util.retry import Retry from typing import Any import yandexcloud -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook from airflow.providers.yandex.utils.credentials import ( get_credentials, @@ -162,58 +157,4 @@ def _get_endpoint(self) -> dict[str, str]: def _get_field(self, field_name: str, default: Any = None) -> Any: if not hasattr(self, "extras"): return default - return get_field_from_extras(self.extras, field_name, default) - - def get_iam_token(self) -> str: - if "oauth" in self.credentials: - return YandexCloudBaseHook._resolve_oauth(self.credentials["oauth"]) - if "service_account_key" in self.credentials: - return YandexCloudBaseHook._resolve_service_account_key(self.credentials["service_account_key"]) - raise AirflowException(f"Unknown credentials type {self.credentials.keys()}") - - @staticmethod - def _resolve_oauth(self, token: str) -> str: - pass - - @staticmethod - def _resolve_service_account_key(sa_info) -> str: - session = YandexCloudBaseHook.create_session() - - api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' - now = int(time.time()) - payload = { - 'aud': api, - 'iss': sa_info["service_account_id"], - 'iat': now, - 'exp': now + 360} - - encoded_token = jwt.encode( - payload, - sa_info["private_key"], - algorithm='PS256', - headers={'kid': sa_info["id"]}) - - data = {"jwt": encoded_token} - iam_response = session.post(api, json=data) - iam_response.raise_for_status() - - return iam_response.json()["iamToken"] - - @staticmethod - def create_session() -> requests.Session: - session = requests.Session() - session.verify = False - retry = Retry( - backoff_factor=0.3, - total=10 - ) - session.mount( - 'http://', - requests.adapters.HTTPAdapter(max_retries=retry) - ) - session.mount( - 'https://', - requests.adapters.HTTPAdapter(max_retries=retry) - ) - - return session + return get_field_from_extras(self.extras, field_name, default) \ No newline at end of file diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 850ad2fe3a76d..088fe579ef9b7 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -15,17 +15,20 @@ # specific language governing permissions and limitations # under the License. from __future__ import annotations +from requests.packages.urllib3.util.retry import Retry +from datetime import timedelta, datetime +from enum import Enum import logging +import requests +import time +import jwt # These two lines enable debugging at httplib level (requests->urllib3->http.client) # You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. # The only thing missing will be the response.body which is not logged. -try: - import http.client as http_client -except ImportError: - # Python 2 - import httplib as http_client + +import http.client as http_client http_client.HTTPConnection.debuglevel = 1 # You must initialize logging, otherwise you'll not see debug output. @@ -38,12 +41,6 @@ from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.exceptions import AirflowException -import requests -from requests.packages.urllib3.util.retry import Retry -from enum import Enum -import time -from datetime import timedelta, datetime -import aiohttp class QueryType(Enum): ANALYTICS = 1 @@ -84,7 +81,7 @@ def start_execute_query(self, query_type: QueryType, query_text: str|None, name def wait_for_query_to_complete(self, execution_timeout: timedelta): try: return self.wait_results(self.query_id, execution_timeout) - except TimeoutError as err: + except TimeoutError: self.stop_query(self.query_id) raise @@ -143,7 +140,6 @@ def get_queryinfo(self, query_id:str)->object: with YQHook.create_session(iam_token) as session: response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}?project={self.default_folder_id}') response.raise_for_status() - print(response.json()) return response.json() @@ -160,21 +156,6 @@ def get_query_status(self, query_id:str)->str: status = response.json()["status"] return status - async def get_query_status_async(self, query_id:str)->str: - iam_token = self.get_iam_token() - - headers = YQHook.get_request_url_header_params(iam_token) - url = f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/status?project={self.default_folder_id}' - - self.log.info(f"Retrieving status for query id {query_id}, url {url}") - - async with aiohttp.ClientSession(headers=headers) as session: - async with session.get(url) as response: - status_code = response.status - assert status_code == 200 - resp = await response.json() - return resp["status"] - @staticmethod def get_request_url_header_params(iam_token: str|None=None)->dict[str,str]: headers = {} @@ -197,8 +178,40 @@ def wait_results(self, query_id:str, execution_timeout: timedelta)->str: time.sleep(2) + def get_iam_token(self) -> str: + if "oauth" in self.credentials: + return self.credentials["oauth"] + if "service_account_key" in self.credentials: + return YQHook._resolve_service_account_key(self.credentials["service_account_key"]) + raise AirflowException(f"Unknown credentials type {self.credentials.keys()}") + @staticmethod - def create_session(iamtoken: str | None = None) -> requests.Session: + def _resolve_service_account_key(sa_info) -> str: + session = YQHook.create_session() + + api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' + now = int(time.time()) + payload = { + 'aud': api, + 'iss': sa_info["service_account_id"], + 'iat': now, + 'exp': now + 360 + } + + encoded_token = jwt.encode( + payload, + sa_info["private_key"], + algorithm='PS256', + headers={'kid': sa_info["id"]}) + + data = {"jwt": encoded_token} + iam_response = session.post(api, json=data) + iam_response.raise_for_status() + + return iam_response.json()["iamToken"] + + @staticmethod + def create_session(iam_token: str | None = None) -> requests.Session: session = requests.Session() session.verify = False session.timeout = 20 @@ -215,10 +228,9 @@ def create_session(iamtoken: str | None = None) -> requests.Session: requests.adapters.HTTPAdapter(max_retries=retry) ) - headers = YQHook.get_request_url_header_params(iamtoken) + headers = YQHook.get_request_url_header_params(iam_token) for k, v in headers.items(): session.headers[k] = v return session - - + \ No newline at end of file From 39ecf8ee32f88306257e6e13891884b3b5470c2f Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 10 Feb 2024 14:17:33 +0300 Subject: [PATCH 04/34] use http client in hook --- airflow/providers/yandex/hooks/http_client.py | 273 ++++++++++++++++++ .../providers/yandex/hooks/yandexcloud_yq.py | 202 ++++--------- .../yandex/operators/yandexcloud_yq.py | 39 +-- 3 files changed, 353 insertions(+), 161 deletions(-) create mode 100644 airflow/providers/yandex/hooks/http_client.py diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py new file mode 100644 index 0000000000000..fd65c4344807b --- /dev/null +++ b/airflow/providers/yandex/hooks/http_client.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import logging +import time +from datetime import datetime +import requests + +from requests.adapters import HTTPAdapter +from typing import Any +from urllib3.util.retry import Retry + +MAX_RETRY_FOR_SESSION = 4 +BACK_OFF_FACTOR = 0.3 +TIME_BETWEEN_RETRIES = 1000 +ERROR_CODES = (500, 502, 504) + + +def requests_retry_session(session, + retries=MAX_RETRY_FOR_SESSION, + back_off_factor=BACK_OFF_FACTOR, + status_force_list=ERROR_CODES): + retry = Retry(total=retries, read=retries, connect=retries, + backoff_factor=back_off_factor, + status_forcelist=status_force_list, + method_whitelist=frozenset(['GET', 'POST'])) + adapter = HTTPAdapter(max_retries=retry) + session.mount('http://', adapter) + session.mount('https://', adapter) + return session + + +class YQHttpClientConfig(object): + def __init__(self, + token: str | None = None, + project: str | None = None, + user_agent: str | None = "Python YQ HTTP SDK") -> None: + + self.token = token + self.project = project + self.user_agent = user_agent + + # urls should not contain trailing / + self.endpoint: str = "https://api.yandex-query.cloud.yandex.net" + self.web_base_url: str = "https://yq.cloud.yandex.ru" + +class YQHttpClientException(BaseException): + def __init__(self, *args: object) -> None: + super().__init__(*args) + + +class YQHttpClient(object): + def __init__(self, config: YQHttpClientConfig): + self.config = config + self.session = requests_retry_session(session=requests.Session()) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.session.close() + + def _build_headers(self, idempotency_key=None, request_id=None) -> dict[str, str]: + headers = { + "Authorization": f"Bearer {self.config.token}" + } + if idempotency_key is not None: + headers["Idempotency-Key"] = idempotency_key + + if request_id is not None: + headers["x-request-id"] = request_id + + if self.config.user_agent is not None: + headers["User-Agent"] = self.config.user_agent + + return headers + + def _build_params(self) -> dict[str, str]: + params = {} + if self.config.project is not None: + params["project"] = self.config.project + + return params + + def _compose_api_url(self, path: str) -> str: + return self.config.endpoint + path + + def _compose_web_url(self, path: str) -> str: + return self.config.web_base_url + path + + def _validate_http_error(self, response, expected_code=200) -> None: + logging.info(f"Response: {response.status_code}, {response.text}") + if response.status_code != expected_code: + if response.headers.get("Content-Type", "").startswith("application/json"): + body = response.json() + raise YQHttpClientException(f"Error occurred: {response.status_code}", + status=body.get("status"), + message=body.get("message"), + details=body.get("details") + ) + + raise YQHttpClientException(f"Error occurred: {response.status_code}, {response.text}") + + def create_query(self, + query_text=None, + type=None, + name=None, + description=None, + idempotency_key=None, + request_id=None, + expected_code=200): + body = dict() + if query_text is not None: + body["text"] = query_text + + if type is not None: + body["type"] = type + + if name is not None: + body["name"] = name + + if description is not None: + body["description"] = description + + response = self.session.post(self._compose_api_url("/api/fq/v1/queries"), + headers=self._build_headers(idempotency_key=idempotency_key, + request_id=request_id), + params=self._build_params(), + json=body) + + self._validate_http_error(response, expected_code=expected_code) + return response.json()["id"] + + def get_query_status(self, query_id, request_id=None, expected_code=200) -> Any: + response = self.session.get( + self._compose_api_url(f"/api/fq/v1/queries/{query_id}/status"), + headers=self._build_headers(request_id=request_id), + params=self._build_params() + ) + + self._validate_http_error(response, expected_code=expected_code) + return response.json()["status"] + + def get_query(self, query_id, request_id=None, expected_code=200) -> Any: + response = self.session.get( + self._compose_api_url(f"/api/fq/v1/queries/{query_id}"), + headers=self._build_headers(request_id=request_id), + params=self._build_params() + ) + + self._validate_http_error(response, expected_code=expected_code) + return response.json() + + def stop_query(self, + query_id: str, + idempotency_key: str | None = None, + request_id: str | None = None, + expected_code: int = 204) -> Any: + + response = self.session.post(self._compose_api_url(f"/api/fq/v1/queries/{query_id}/stop"), + headers=self._build_headers(idempotency_key=idempotency_key, request_id=request_id), + params=self._build_params()) + self._validate_http_error(response, expected_code=expected_code) + return response + + def wait_query_to_complete(self, query_id, execution_timeout=None, stop_on_timeout=False) -> str: + status = None + delay = 0.2 # start with 0.2 sec + try: + start = datetime.now() + while True: + if execution_timeout is not None and datetime.now() > start + execution_timeout: + raise TimeoutError(f"Query {query_id} execution timeout, last status {status}") + + status = self.get_query_status(query_id) + if status not in ["RUNNING", "PENDING"]: + return status + + time.sleep(delay) + delay *= 2 + delay = min(2, delay) # up to 2 seconds + + except TimeoutError: + if stop_on_timeout: + self.stop_query(query_id) + raise + + def wait_query_to_succeed(self, query_id, execution_timeout=None, stop_on_timeout=False) -> int: + status = self.wait_query_to_complete( + query_id=query_id, + execution_timeout=execution_timeout, + stop_on_timeout=stop_on_timeout + ) + + query = self.get_query(query_id) + if status != "COMPLETED": + issues = query["issues"] + raise RuntimeError(f"Query {query_id} failed", issues=issues) + + return len(query["result_sets"]) + + def get_query_result_set_page(self, + query_id, + result_set_index, + offset=None, + limit=None, + raw_format=False, + request_id=None, + expected_code=200) -> Any: + params = self._build_params() + if offset is not None: + params["offset"] = offset + + if limit is not None: + params["limit"] = limit + + response = self.session.get( + self._compose_api_url(f"/api/fq/v1/queries/{query_id}/results/{result_set_index}"), + headers=self._build_headers(request_id=request_id), + params=params + ) + + self._validate_http_error(response, expected_code=expected_code) + return response.json() + + def get_query_result_set(self, query_id: str, result_set_index: int, raw_format: bool = False) -> Any: + offset = 0 + limit = 1000 + columns = None + rows = [] + while True: + part = self.get_query_result_set_page( + query_id, + result_set_index=result_set_index, + offset=offset, + limit=limit, + raw_format=raw_format + ) + + if columns is None: + columns = part["columns"] + + r = part["rows"] + rows.extend(r) + if len(r) != limit: + break + + offset += limit + + return {"rows": rows, "columns": columns} + + def get_query_all_result_sets(self, query_id: str, result_set_count: int, raw_format: bool = False) -> Any: + result = list() + for i in range(0, result_set_count): + r = self.get_query_result_set( + query_id, + result_set_index=i, + raw_format=raw_format + ) + + if result_set_count == 1: + return r + + result.append(r) + + return result + + def get_openapi_spec(self) -> str: + response = self.session.get(self._compose_api_url("/resources/v1/openapi.yaml")) + self._validate_http_error(response) + return response.text + + def compose_query_web_link(self, query_id) -> str: + return self._compose_web_url(f"/folders/{self.config.project}/ide/queries/{query_id}") diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 088fe579ef9b7..e1329e85bfd4e 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -22,14 +22,15 @@ import logging import requests import time +from typing import Any import jwt # These two lines enable debugging at httplib level (requests->urllib3->http.client) # You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. # The only thing missing will be the response.body which is not logged. -import http.client as http_client -http_client.HTTPConnection.debuglevel = 1 +import http.client +http.client.HTTPConnection.debuglevel = 1 # You must initialize logging, otherwise you'll not see debug output. logging.basicConfig() @@ -39,8 +40,9 @@ requests_log.propagate = True from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook +from airflow.providers.yandex.hooks.http_client import YQHttpClientConfig, YQHttpClient from airflow.exceptions import AirflowException - +from airflow.providers.yandex.utils.user_agent import provider_user_agent class QueryType(Enum): ANALYTICS = 1 @@ -48,135 +50,50 @@ class QueryType(Enum): class YQHook(YandexCloudBaseHook): """ - A base hook for Yandex.Cloud Data Proc. - - :param yandex_conn_id: The connection ID to use when fetching connection info. + A hook for Yandex Query """ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.query_id: str | None = None - - - def start_execute_query(self, query_type: QueryType, query_text: str|None, name: str|None=None, description: str | None = None) -> str: - type = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" + config = YQHttpClientConfig( + token=self.get_iam_token(), + project=self.default_folder_id, + user_agent=provider_user_agent() + ) - with YQHook.create_session(self.get_iam_token()) as session: - data = { - "name": name, - "type": type, - "text": query_text, - "description": description - } + self.client: YQHttpClient = YQHttpClient(config=config) - self.log.info(f"folder={self.default_folder_id}") - response = session.post(f"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries?project={self.default_folder_id}", json=data) - response.raise_for_status() - - self.query_id = response.json()["id"] - - return self.query_id - - def wait_for_query_to_complete(self, execution_timeout: timedelta): - try: - return self.wait_results(self.query_id, execution_timeout) - except TimeoutError: - self.stop_query(self.query_id) - raise - - def get_query_result(self, query_id): - self.log.info(f"get_query_result query_id={query_id}") - query_info = self.get_queryinfo(query_id) - if query_info["status"] == "FAILED": - issues = query_info["issues"] - raise RuntimeError("Query failed", issues=issues) - - result_set_count = len(query_info["result_sets"]) - self.log.debug(f"result set count {result_set_count}") - - query_results = self.query_results(query_id, result_set_count) - self.log.debug(query_results) - return query_results - - def query_results(self, query_id:str, result_set_count:int)->object: - results = list() - limit = 1000 - offset = 0 - - iam_token = self.get_iam_token() - with YQHook.create_session(iam_token) as session: - for result_index in range(0, result_set_count): - columns = None - rows = [] - while True: - print(f"limit={limit} offset={offset}") - response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/results/{result_index}?project={self.default_folder_id}&limit={limit}&offset={offset}') - response.raise_for_status() - - qresults = response.json() - print(qresults) - if columns is None: - columns = qresults["columns"] - - rows.extend( qresults["rows"]) - if len(qresults["rows"]) != limit: - break - else: - offset += limit - - results.append({"rows":rows, "columns": columns}) - - if len(results) == 1: - return results[0] - else: - return results - - def stop_current_query(self)->None: - self.stop_query(self.query_id) - - def get_queryinfo(self, query_id:str)->object: - iam_token = self.get_iam_token() - with YQHook.create_session(iam_token) as session: - response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}?project={self.default_folder_id}') - response.raise_for_status() - - return response.json() - - def stop_query(self, query_id:str)->None: - iam_token = self.get_iam_token() - with YQHook.create_session(iam_token) as session: - session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/stop?project={self.default_folder_id}') - - def get_query_status(self, query_id:str)->str: - iam_token = self.get_iam_token() - with YQHook.create_session(iam_token) as session: - response = session.get(f'https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/{query_id}/status?project={self.default_folder_id}') - response.raise_for_status() - status = response.json()["status"] - return status + def close(self): + self.client.close() - @staticmethod - def get_request_url_header_params(iam_token: str|None=None)->dict[str,str]: - headers = {} - if iam_token is not None: - headers['Authorization'] = f"Bearer {iam_token}" + def create_query(self, query_text: str|None, name: str|None=None, description: str | None = None, query_type: QueryType = QueryType.ANALYTICS) -> str: + type = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" + + return self.client.create_query( + name=name, + type=type, + query_text=query_text, + description=description + ) - return headers + def stop_query(self, query_id: str) -> None: + self.stop_query(query_id) - def wait_results(self, query_id:str, execution_timeout: timedelta)->str: - execution_timeout = execution_timeout if execution_timeout is not None else timedelta(minutes=30) + def get_query(self, query_id: str) -> Any: + return self.client.get_query(query_id) - start = datetime.now() - while True: - if datetime.now() > start + execution_timeout: - raise TimeoutError("Query execution timeout") + def get_query_status(self, query_id: str) -> str: + return self.client.get_query_status(query_id) - status = self.get_query_status(query_id) - if status not in ["RUNNING", "PENDING"]: - return status + def wait_results(self, query_id: str) -> Any: + result_set_count = self.client.wait_query_to_succeed( + query_id, + execution_timeout=timedelta(minutes=30), + stop_on_timeout=True + ) - time.sleep(2) + return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) def get_iam_token(self) -> str: if "oauth" in self.credentials: @@ -185,33 +102,36 @@ def get_iam_token(self) -> str: return YQHook._resolve_service_account_key(self.credentials["service_account_key"]) raise AirflowException(f"Unknown credentials type {self.credentials.keys()}") + def compose_query_web_link(self, query_id:str): + return self.client.compose_query_web_link(query_id) + @staticmethod def _resolve_service_account_key(sa_info) -> str: - session = YQHook.create_session() - - api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' - now = int(time.time()) - payload = { - 'aud': api, - 'iss': sa_info["service_account_id"], - 'iat': now, - 'exp': now + 360 - } + with YQHook.create_session() as session: + api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' + now = int(time.time()) + payload = { + 'aud': api, + 'iss': sa_info["service_account_id"], + 'iat': now, + 'exp': now + 360 + } - encoded_token = jwt.encode( - payload, - sa_info["private_key"], - algorithm='PS256', - headers={'kid': sa_info["id"]}) + encoded_token = jwt.encode( + payload, + sa_info["private_key"], + algorithm='PS256', + headers={'kid': sa_info["id"]} + ) - data = {"jwt": encoded_token} - iam_response = session.post(api, json=data) - iam_response.raise_for_status() + data = {"jwt": encoded_token} + iam_response = session.post(api, json=data) + iam_response.raise_for_status() - return iam_response.json()["iamToken"] + return iam_response.json()["iamToken"] @staticmethod - def create_session(iam_token: str | None = None) -> requests.Session: + def create_session() -> requests.Session: session = requests.Session() session.verify = False session.timeout = 20 @@ -228,9 +148,5 @@ def create_session(iam_token: str | None = None) -> requests.Session: requests.adapters.HTTPAdapter(max_retries=retry) ) - headers = YQHook.get_request_url_header_params(iam_token) - for k, v in headers.items(): - session.headers[k] = v - return session \ No newline at end of file diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index cd9d41c9fb7c5..ad68d79a50eaa 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -17,11 +17,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, Any from datetime import timedelta from airflow.configuration import conf - from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.models.taskinstancekey import TaskInstanceKey @@ -30,12 +29,13 @@ if TYPE_CHECKING: from airflow.utils.context import Context +XCOM_WEBLINK_KEY="web_link" class YQLink(BaseOperatorLink): name = "Yandex Query" def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): - return XCom.get_value(key="web_link", ti_key=ti_key) or "https://yq.cloud.yandex.ru" + return XCom.get_value(key=XCOM_WEBLINK_KEY, ti_key=ti_key) or "https://yq.cloud.yandex.ru" class YQExecuteQueryOperator(SQLExecuteQueryOperator): @@ -75,8 +75,9 @@ def __init__( self.service_account_id = service_account_id self.hook: YQHook | None = None + self.query_id: str | None - def execute(self, context: Context) -> None: + def execute(self, context: Context) -> Any: self.hook = YQHook( yandex_conn_id=self.connection_id, default_folder_id=self.folder_id, @@ -84,20 +85,22 @@ def execute(self, context: Context) -> None: default_service_account_id=self.service_account_id ) - self.hook.start_execute_query(self.type, self.sql, self.name, self.description) - - # pass to YQLink - web_link = f"https://yq.cloud.yandex.ru/folders/{self.folder_id}/ide/queries/{self.hook.query_id}" - context["ti"].xcom_push(key="web_link", value=web_link) + self.query_id = self.hook.create_query( + query_type=self.type, + query_text=self.sql, + name=self.name, + description=self.description + ) - self.hook.wait_for_query_to_complete(self.execution_timeout) - return self.hook.get_query_result(self.hook.query_id) + # pass to YQLink + web_link = self.hook.compose_query_web_link(self.query_id) + context["ti"].xcom_push(key=XCOM_WEBLINK_KEY, value=web_link) - def on_kill(self): - if self.hook is not None: - self.hook.stop_current_query() + results = self.hook.wait_results(self.query_id) + # forget query to avoid 'stop_query' in on_kill + self.query_id = None + return results - @staticmethod - def to_dataframe(data): - column_names = [column["name"] for column in data["columns"]] - return pd.DataFrame(data["rows"], columns= column_names) + def on_kill(self) -> None: + if self.hook is not None and self.query_id is not None: + self.hook.stop_query(self.query_id) From 71917b54c211cb48963ef92440a7b8e81b86ecbf Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Mon, 12 Feb 2024 19:52:09 +0300 Subject: [PATCH 05/34] add yq_results --- airflow/providers/yandex/hooks/http_client.py | 23 +- airflow/providers/yandex/hooks/yq_results.py | 303 ++++++++++++++++++ 2 files changed, 319 insertions(+), 7 deletions(-) create mode 100644 airflow/providers/yandex/hooks/yq_results.py diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index fd65c4344807b..070336435747a 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -22,7 +22,7 @@ def requests_retry_session(session, retry = Retry(total=retries, read=retries, connect=retries, backoff_factor=back_off_factor, status_forcelist=status_force_list, - method_whitelist=frozenset(['GET', 'POST'])) + allowed_methods=frozenset(['GET', 'POST'])) adapter = HTTPAdapter(max_retries=retry) session.mount('http://', adapter) session.mount('https://', adapter) @@ -35,6 +35,7 @@ def __init__(self, project: str | None = None, user_agent: str | None = "Python YQ HTTP SDK") -> None: + assert len(token) > 0, "empty token" self.token = token self.project = project self.user_agent = user_agent @@ -43,9 +44,13 @@ def __init__(self, self.endpoint: str = "https://api.yandex-query.cloud.yandex.net" self.web_base_url: str = "https://yq.cloud.yandex.ru" -class YQHttpClientException(BaseException): - def __init__(self, *args: object) -> None: - super().__init__(*args) + +class YQHttpClientException(Exception): + def __init__(self, prefix_message: str, status: str, message: str, details: Any) -> None: + super().__init__(f"{prefix_message} : {message}") + self.status = status + self.message = message + self.details = details class YQHttpClient(object): @@ -155,9 +160,13 @@ def stop_query(self, idempotency_key: str | None = None, request_id: str | None = None, expected_code: int = 204) -> Any: - + + headers = self._build_headers( + idempotency_key=idempotency_key, + request_id=request_id + ) response = self.session.post(self._compose_api_url(f"/api/fq/v1/queries/{query_id}/stop"), - headers=self._build_headers(idempotency_key=idempotency_key, request_id=request_id), + headers=headers, params=self._build_params()) self._validate_http_error(response, expected_code=expected_code) return response @@ -261,7 +270,7 @@ def get_query_all_result_sets(self, query_id: str, result_set_count: int, raw_fo return r result.append(r) - + return result def get_openapi_spec(self) -> str: diff --git a/airflow/providers/yandex/hooks/yq_results.py b/airflow/providers/yandex/hooks/yq_results.py new file mode 100644 index 0000000000000..4098ba30b9bd7 --- /dev/null +++ b/airflow/providers/yandex/hooks/yq_results.py @@ -0,0 +1,303 @@ +from typing import Any, Optional +import base64 +import pprint +import dateutil.parser +from datetime import datetime +from decimal import Decimal + + +class YQResults: + """Holds and formats query execution results""" + + def __init__(self, results: list[dict[str:Any]] | dict[str:Any]): + self._raw_results = results if results is list else [results] + self._results = None + + def _convert(self): + return [YQResults._convert_single(result) for result in self._raw_results] + + @staticmethod + def _convert_from_float(value: float|str) -> float: + # special values, e.g inf encoded as str, normal values are in float + return float(value) + + @staticmethod + def _convert_from_pgfloat(value: str|None) -> float: + if value is None: + return None + return float(value) + + @staticmethod + def _convert_from_pgint(value: str | None) -> int: + if value is None: + return None + return int(value) + + @staticmethod + def _convert_from_decimal(value: str) -> Decimal: + return Decimal(value) + + @staticmethod + def _convert_from_pgnumeric(value: str | None) -> Decimal: + if value is None: + return None + return Decimal(value) + + @staticmethod + def _convert_from_base64(value: str) -> str | bytes: + b = base64.b64decode(value) + try: + return b.decode('utf-8') + except: + return b + + @staticmethod + def _convert_from_datetime(value: str) -> datetime: + # suitable for yql data and datetime parsing + return dateutil.parser.isoparse(value) + + @staticmethod + def _convert_from_pgdatetime(value: str | None) -> datetime: + if value is None: + return None + return dateutil.parser.isoparse(value) + + @staticmethod + def _convert_from_enum(value: list) -> str: + return str(value[0]) + + @staticmethod + def _extract_from_optional(type: str) -> str: + # Uint16? -> Uint16 + if type.endswith("?"): + return type[0:-1] + + # Optional -> Uint16 + return type[len("Optional<"):-1] + + @staticmethod + def _extract_from_set(type: str) -> str: + # Set -> Uint16 + return type[len("Set<"):-1] + + @staticmethod + def _extract_from_list(type: str) -> str: + # List -> Uint16 + return type[len("List<"):-1] + + @staticmethod + def _split_type_list(type_list: str) -> list[str]: + # naive implementation + # todo fix it + return type_list.split(",") + + @staticmethod + def _extract_from_tuple(type: str) -> str: + # Tuple -> [Uint16, String, Double] + return YQResults._split_type_list(type[len("Tuple<"):-1]) + + @staticmethod + def _extract_from_dict(type: str) -> (str, str): + # Dict -> (Uint16, String) + [key_type, value_type] = YQResults._split_type_list(type[len("Dict<"):-1]) + return key_type, value_type + + @staticmethod + def _extract_from_variant_over_struct(type: str) -> (str, str): + # Variant<'One':Int32,'Two':String> -> {One: Int32, Two: String} + types_with_names = YQResults._split_type_list(type[len("Variant<"):-1]) + result = {} + for t in types_with_names: + [n, t] = t.split(":") + # strip ' + n = n[1:-1] + result[n] = t + return result + + @staticmethod + def _extract_from_variant_over_tuple(type: str) -> (str, str): + # Variant -> [Int32, String] + return YQResults._split_type_list(type[len("Variant<"):-1]) + + @staticmethod + def _convert_from_optional(value: list[Any]) -> Optional[Any]: + # Optional types are encoded as [[]] objects + # If type is Uint16, value is encoded as {"rows":[[value]]} + # If type is Optional, value is encoded as {"rows":[[[value]]]} + # If value is None than result is {"rows":[[[]]]} + # So check if len equals 1 it means that it contains value + # if len is 0 it means it has no value i.e. value is None + assert len(value) < 2, str(value) + if len(value) == 1: + return value[0] + + return None + + @staticmethod + def id(v): + return v + + @staticmethod + def _get_converter(column_type: str) -> Any: + """Returns converter based on column type""" + + # primitives + if column_type in ["Int8", "Int16", "Int32", "Int64", + "Uint8", "Uint16", "Uint32", "Uint64", + "Bool", "Utf8", "Uuid", + "Void", "Null", + "EmptyList", "Struct<>", "Tuple<>"]: + return YQResults.id + + if column_type == "String": + return YQResults._convert_from_base64 + + if column_type in ["Float", "Double"]: + return YQResults._convert_from_float + + if column_type.startswith("Decimal("): + return YQResults._convert_from_decimal + + if column_type.startswith("Enum<"): + return YQResults._convert_from_enum + + if column_type in ["Date", "Datetime", "Timestamp"]: + return YQResults._convert_from_datetime + + # containers + if column_type.startswith("Optional<") or column_type.endswith("?"): + # If type is Optional than get base type + inner_converter = YQResults._get_converter( + YQResults._extract_from_optional(column_type)) + + # Remove "Optional" encoding + # and convert resulting value as others + def convert(x): + inner_value = YQResults._convert_from_optional(x) + if inner_value is None: + return None + return inner_converter(inner_value) + + return convert + + if column_type.startswith("Set<"): + inner_converter = YQResults._get_converter(YQResults._extract_from_set(column_type)) + + def convert(x): + return {inner_converter(v) for v in x} + + return convert + + if column_type.startswith("List<"): + inner_converter = YQResults._get_converter(YQResults._extract_from_list(column_type)) + + def convert(x): + return [inner_converter(v) for v in x] + + return convert + + if column_type.startswith("Tuple<"): + inner_types = YQResults._extract_from_tuple(column_type) + inner_converters = [YQResults._get_converter(t) for t in inner_types] + + def convert(x): + assert len(x) == len(inner_converters), f"Wrong lenght for tuple value: {len(x)} != {len(inner_converters)}" + return tuple([c(v) for (c, v) in zip(inner_converters, x)]) + + return convert + + # variant over struct + if column_type.startswith("Variant<'"): + inner_types = YQResults._extract_from_variant_over_struct(column_type) + inner_converters = {k: YQResults._get_converter(t) for k, t in inner_types.items()} + + def convert(x): + return inner_converters[x[0]](x[1]) + + return convert + + # variant over tuple + if column_type.startswith("Variant<"): + inner_types = YQResults._extract_from_variant_over_tuple(column_type) + inner_converters = [YQResults._get_converter(t) for t in inner_types] + + def convert(x): + return inner_converters[x[0]](x[1]) + + return convert + + if column_type == "EmptyDict": + def convert(x): + return {} + + return convert + + if column_type.startswith("Dict<"): + key_type, value_type = YQResults._extract_from_dict(column_type) + key_converter = YQResults._get_converter(key_type) + value_converter = YQResults._get_converter(value_type) + + def convert(x): + return {key_converter(v[0]): value_converter(v[1]) for v in x} + + return convert + + # pg types + if column_type.startswith("pgfloat"): + return YQResults._convert_from_pgfloat + + if column_type in ["pgint2", "pgint4", "pgint8"]: + return YQResults._convert_from_pgint + + if column_type == "pgnumeric": + return YQResults._convert_from_pgnumeric + + if column_type in ["pgdate", "pgtimestamp"]: + return YQResults._convert_from_pgdatetime + + if column_type.startswith("pg"): + return YQResults.id + + # unsupported type + return YQResults.id + + @staticmethod + def _convert_single(results: dict[str, Any]) -> dict[str, Any]: + converters = [] + converted_results = [] + for column in results["columns"]: + converters.append(YQResults._get_converter(column["type"])) + + for row in results["rows"]: + new_row = [] + for index, value in enumerate(row): + converter = converters[index] + new_row.append( + value if converter is None else converter(value)) + + converted_results.append(new_row) + + return {"rows": converted_results, "columns": results["columns"]} + + def _repr_pretty_(self, p, cycle): + p.text(pprint.pformat(self._results)) + + @property + def results(self): + if self._results is None: + self._results = self._convert() + + return self._results + + @property + def raw_results(self): + return self._raw_results + + def to_table(self, index: int = 0): + return self.results[index]["rows"] + + def to_dataframe(self, index: int = 0): + result_set = self.results[index] + columns = [column["name"] for column in result_set["columns"]] + import pandas + return pandas.DataFrame(result_set["rows"], columns=columns) From f9d7e72b7fbb4005e51742458ff77bfc78c79cd6 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Tue, 13 Feb 2024 12:57:46 +0300 Subject: [PATCH 06/34] add token_prefix, format exception message --- airflow/providers/yandex/hooks/http_client.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index 070336435747a..0c7f52d804f2d 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -43,13 +43,14 @@ def __init__(self, # urls should not contain trailing / self.endpoint: str = "https://api.yandex-query.cloud.yandex.net" self.web_base_url: str = "https://yq.cloud.yandex.ru" + self.token_prefix = "Bearer " class YQHttpClientException(Exception): - def __init__(self, prefix_message: str, status: str, message: str, details: Any) -> None: - super().__init__(f"{prefix_message} : {message}") + def __init__(self, message: str, status: str, msg: str, details: Any) -> None: + super().__init__(message) self.status = status - self.message = message + self.msg = msg self.details = details @@ -66,7 +67,7 @@ def __exit__(self, *args): def _build_headers(self, idempotency_key=None, request_id=None) -> dict[str, str]: headers = { - "Authorization": f"Bearer {self.config.token}" + "Authorization": f"{self.config.token_prefix}{self.config.token}" } if idempotency_key is not None: headers["Idempotency-Key"] = idempotency_key @@ -97,10 +98,13 @@ def _validate_http_error(self, response, expected_code=200) -> None: if response.status_code != expected_code: if response.headers.get("Content-Type", "").startswith("application/json"): body = response.json() - raise YQHttpClientException(f"Error occurred: {response.status_code}", - status=body.get("status"), - message=body.get("message"), - details=body.get("details") + status = body.get("status") + msg = body.get("message") + details = body.get("details") + raise YQHttpClientException(f"Error occurred. http code={response.status_code}, status={status}, msg={msg}, details={details}", + status=status, + msg=msg, + details=details ) raise YQHttpClientException(f"Error occurred: {response.status_code}, {response.text}") From 318f652a64064c21a286f55ca7b4454c634bf6da Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Tue, 13 Feb 2024 15:34:46 +0300 Subject: [PATCH 07/34] use YQResults inside client --- airflow/providers/yandex/hooks/http_client.py | 9 ++++-- .../hooks/{yq_results.py => query_results.py} | 29 +++++++++---------- .../providers/yandex/hooks/yandexcloud_yq.py | 3 +- 3 files changed, 22 insertions(+), 19 deletions(-) rename airflow/providers/yandex/hooks/{yq_results.py => query_results.py} (92%) diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index 0c7f52d804f2d..efc69e0a3b766 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -9,6 +9,8 @@ from typing import Any from urllib3.util.retry import Retry +from .query_results import YQResults + MAX_RETRY_FOR_SESSION = 4 BACK_OFF_FACTOR = 0.3 TIME_BETWEEN_RETRIES = 1000 @@ -259,7 +261,11 @@ def get_query_result_set(self, query_id: str, result_set_index: int, raw_format: offset += limit - return {"rows": rows, "columns": columns} + result = {"rows": rows, "columns": columns} + if raw_format: + return result + + return YQResults(result).results def get_query_all_result_sets(self, query_id: str, result_set_count: int, raw_format: bool = False) -> Any: result = list() @@ -274,7 +280,6 @@ def get_query_all_result_sets(self, query_id: str, result_set_count: int, raw_fo return r result.append(r) - return result def get_openapi_spec(self) -> str: diff --git a/airflow/providers/yandex/hooks/yq_results.py b/airflow/providers/yandex/hooks/query_results.py similarity index 92% rename from airflow/providers/yandex/hooks/yq_results.py rename to airflow/providers/yandex/hooks/query_results.py index 4098ba30b9bd7..e40afb39324fb 100644 --- a/airflow/providers/yandex/hooks/yq_results.py +++ b/airflow/providers/yandex/hooks/query_results.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Any, Optional import base64 import pprint @@ -9,15 +10,12 @@ class YQResults: """Holds and formats query execution results""" - def __init__(self, results: list[dict[str:Any]] | dict[str:Any]): - self._raw_results = results if results is list else [results] + def __init__(self, results: dict[str, Any]): + self._raw_results = results self._results = None - def _convert(self): - return [YQResults._convert_single(result) for result in self._raw_results] - @staticmethod - def _convert_from_float(value: float|str) -> float: + def _convert_from_float(value: float | str) -> float: # special values, e.g inf encoded as str, normal values are in float return float(value) @@ -261,14 +259,13 @@ def convert(x): # unsupported type return YQResults.id - @staticmethod - def _convert_single(results: dict[str, Any]) -> dict[str, Any]: + def _convert(self): converters = [] converted_results = [] - for column in results["columns"]: + for column in self._raw_results["columns"]: converters.append(YQResults._get_converter(column["type"])) - for row in results["rows"]: + for row in self._raw_results["rows"]: new_row = [] for index, value in enumerate(row): converter = converters[index] @@ -277,7 +274,7 @@ def _convert_single(results: dict[str, Any]) -> dict[str, Any]: converted_results.append(new_row) - return {"rows": converted_results, "columns": results["columns"]} + self._results = {"rows": converted_results, "columns": self._raw_results["columns"]} def _repr_pretty_(self, p, cycle): p.text(pprint.pformat(self._results)) @@ -285,7 +282,7 @@ def _repr_pretty_(self, p, cycle): @property def results(self): if self._results is None: - self._results = self._convert() + self._convert() return self._results @@ -293,11 +290,11 @@ def results(self): def raw_results(self): return self._raw_results - def to_table(self, index: int = 0): - return self.results[index]["rows"] + def to_table(self): + return self._results["rows"] - def to_dataframe(self, index: int = 0): - result_set = self.results[index] + def to_dataframe(self): + result_set = self._results columns = [column["name"] for column in result_set["columns"]] import pandas return pandas.DataFrame(result_set["rows"], columns=columns) diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index e1329e85bfd4e..168de52182659 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -40,10 +40,11 @@ requests_log.propagate = True from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook -from airflow.providers.yandex.hooks.http_client import YQHttpClientConfig, YQHttpClient from airflow.exceptions import AirflowException from airflow.providers.yandex.utils.user_agent import provider_user_agent +from .http_client import YQHttpClientConfig, YQHttpClient + class QueryType(Enum): ANALYTICS = 1 STREAMING = 2 From 8faf664792788e8818cfcb0fb041f4dcecb39ee6 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Tue, 13 Feb 2024 20:16:30 +0300 Subject: [PATCH 08/34] add tests, fix provider.yaml --- airflow/providers/yandex/provider.yaml | 2 + .../yandex/hooks/test_yandexcloud_yq.py | 149 ++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 tests/providers/yandex/hooks/test_yandexcloud_yq.py diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index 30d556f0c81d7..a433b3c72d443 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -88,6 +88,8 @@ hooks: - integration-name: Yandex.Cloud YQ python-modules: - airflow.providers.yandex.hooks.yandexcloud_yq + - airflow.providers.yandex.hooks.http_client + - airflow.providers.yandex.hooks.query_results connection-types: - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py new file mode 100644 index 0000000000000..8acfc81ceade7 --- /dev/null +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import responses +from decimal import Decimal +from responses import matchers +from unittest import mock + +from airflow.models import Connection +from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook + +OAUTH_TOKEN = "my_oauth_token" +SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"-----BEGIN PRIVATE KEY----- my_pk"}""" + +class TestYandexCloudYqHook: + def _init_hook(self): + with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection: + mock_get_connection.return_value = self.connection + self.hook = YQHook(default_folder_id="my_folder_id") + + def setup_method(self): + #self.connection = Connection(extra=json.dumps({"oauth": OAUTH_TOKEN})) + self.connection = Connection(extra=json.dumps({"service_account_json": SERVICE_ACCOUNT_AUTH_KEY_JSON})) + + @responses.activate() + @mock.patch("jwt.encode") + def test_simple_select_via_iam(self, mock_jwt): + responses.post( + "https://iam.api.cloud.yandex.net/iam/v1/tokens", + json={"iamToken": "super_token"}, + status=200, + ) + mock_jwt.return_value = "zzzz" + + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), + matchers.query_param_matcher({"project": "my_folder_id"}) + ], + json={"id": "query1"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "COMPLETED"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", + json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0", + json={"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}, + status=200, + ) + + self._init_hook() + query_id = self.hook.create_query(query_text="select 777", name="my query", description="my desc") + assert query_id == "query1" + + results = self.hook.wait_results(query_id) + assert results == {"rows": [[777]], "columns": [ + {"name": "column0", "type": "Int32"}]} + + @responses.activate() + @mock.patch("jwt.encode") + def test_integral_results(self, mock_jwt): + responses.post( + "https://iam.api.cloud.yandex.net/iam/v1/tokens", + json={"iamToken": "super_token"}, + status=200, + ) + mock_jwt.return_value = "zzzz" + + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), + matchers.query_param_matcher({"project": "my_folder_id"}) + ], + json={"id": "query1"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "COMPLETED"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", + json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0", + json={ + "rows":[[100,-100,200,200,10000000000,-20000000000,"18014398509481984","-18014398509481984",123.5,-789.125,"inf",True,False,"aGVsbG8=","hello","1.23","he\"llo_again","Я Привет",1,2,3,4]], + "columns":[{"name":"column0","type":"Int32"},{"name":"column1","type":"Int32"},{"name":"column2","type":"Int64"},{"name":"column3","type":"Uint64"},{"name":"column4","type":"Uint64"},{"name":"column5","type":"Int64"},{"name":"column6","type":"Int64"},{"name":"column7","type":"Int64"},{"name":"column8","type":"Float"},{"name":"column9","type":"Double"},{"name":"column10","type":"Double"},{"name":"column11","type":"Bool"},{"name":"column12","type":"Bool"},{"name":"column13","type":"String"},{"name":"column14","type":"Utf8"},{"name":"column15","type":"Decimal(6,3)"},{"name":"column16","type":"Utf8"},{"name":"column17","type":"Utf8"},{"name":"column18","type":"Int8"},{"name":"column19","type":"Int16"},{"name":"column20","type":"Uint8"},{"name":"column21","type":"Uint16"}] + }, + status=200, + ) + + self._init_hook() + query_id = self.hook.create_query(query_text="complex_query1", name="my query", description="my desc") + assert query_id == "query1" + + results = self.hook.wait_results(query_id) + assert results == { + "rows": [ + [ + 100, -100, + 200, 200, + 10000000000, -20000000000, + "18014398509481984", "-18014398509481984", + 123.5, -789.125, + float("inf"), True, + False, "hello", + "hello", Decimal("1.23"), + "he\"llo_again", "Я Привет", + 1, 2, 3, 4 + ] + ], + "columns":[{"name":"column0","type":"Int32"},{"name":"column1","type":"Int32"},{"name":"column2","type":"Int64"},{"name":"column3","type":"Uint64"},{"name":"column4","type":"Uint64"},{"name":"column5","type":"Int64"},{"name":"column6","type":"Int64"},{"name":"column7","type":"Int64"},{"name":"column8","type":"Float"},{"name":"column9","type":"Double"},{"name":"column10","type":"Double"},{"name":"column11","type":"Bool"},{"name":"column12","type":"Bool"},{"name":"column13","type":"String"},{"name":"column14","type":"Utf8"},{"name":"column15","type":"Decimal(6,3)"},{"name":"column16","type":"Utf8"},{"name":"column17","type":"Utf8"},{"name":"column18","type":"Int8"},{"name":"column19","type":"Int16"},{"name":"column20","type":"Uint8"},{"name":"column21","type":"Uint16"}] + } \ No newline at end of file From d5f0e3f2f19dd02a2567896661a516c9f74f8863 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Wed, 14 Feb 2024 12:20:56 +0300 Subject: [PATCH 09/34] fix oauth token usage, add tests for complex results --- .../providers/yandex/hooks/yandexcloud_yq.py | 8 +- .../yandex/hooks/test_yandexcloud_yq.py | 139 +++++++++++++----- 2 files changed, 106 insertions(+), 41 deletions(-) diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 168de52182659..3da21ef01e9a3 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -79,7 +79,7 @@ def create_query(self, query_text: str|None, name: str|None=None, description: s ) def stop_query(self, query_id: str) -> None: - self.stop_query(query_id) + self.client.stop_query(query_id) def get_query(self, query_id: str) -> Any: return self.client.get_query(query_id) @@ -97,11 +97,11 @@ def wait_results(self, query_id: str) -> Any: return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) def get_iam_token(self) -> str: - if "oauth" in self.credentials: - return self.credentials["oauth"] + if "token" in self.credentials: + return self.credentials["token"] if "service_account_key" in self.credentials: return YQHook._resolve_service_account_key(self.credentials["service_account_key"]) - raise AirflowException(f"Unknown credentials type {self.credentials.keys()}") + raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}") def compose_query_web_link(self, query_id:str): return self.client.compose_query_web_link(query_id) diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index 8acfc81ceade7..9475184705655 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -18,6 +18,8 @@ import json import responses +from datetime import datetime +from dateutil.tz import tzutc from decimal import Decimal from responses import matchers from unittest import mock @@ -35,12 +37,9 @@ def _init_hook(self): self.hook = YQHook(default_folder_id="my_folder_id") def setup_method(self): - #self.connection = Connection(extra=json.dumps({"oauth": OAUTH_TOKEN})) - self.connection = Connection(extra=json.dumps({"service_account_json": SERVICE_ACCOUNT_AUTH_KEY_JSON})) + self.connection = Connection(extra={"service_account_json": SERVICE_ACCOUNT_AUTH_KEY_JSON}) - @responses.activate() - @mock.patch("jwt.encode") - def test_simple_select_via_iam(self, mock_jwt): + def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): responses.post( "https://iam.api.cloud.yandex.net/iam/v1/tokens", json={"iamToken": "super_token"}, @@ -58,6 +57,12 @@ def test_simple_select_via_iam(self, mock_jwt): status=200, ) + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "RUNNING"}, + status=200, + ) + responses.get( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", json={"status": "COMPLETED"}, @@ -72,57 +77,64 @@ def test_simple_select_via_iam(self, mock_jwt): responses.get( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0", - json={"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}, + json=query_results_json, + status=200, + ) + + @responses.activate() + def test_oauth_token_usage(self): + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher({"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"}), + matchers.query_param_matcher({"project": "my_folder_id"}) + ], + json={"id": "query1"}, status=200, ) + self.connection = Connection(extra={"oauth": OAUTH_TOKEN}) + self._init_hook() + query_id = self.hook.create_query(query_text="select 777") + assert query_id == "query1" + + @responses.activate() + @mock.patch("jwt.encode") + def test_select_results(self, mock_jwt): + self.setup_mocks_for_query_execution(mock_jwt, {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}) + self._init_hook() query_id = self.hook.create_query(query_text="select 777", name="my query", description="my desc") assert query_id == "query1" + assert self.hook.compose_query_web_link(query_id) == "https://yq.cloud.yandex.ru/folders/my_folder_id/ide/queries/query1" + results = self.hook.wait_results(query_id) assert results == {"rows": [[777]], "columns": [ {"name": "column0", "type": "Int32"}]} - @responses.activate() - @mock.patch("jwt.encode") - def test_integral_results(self, mock_jwt): - responses.post( - "https://iam.api.cloud.yandex.net/iam/v1/tokens", - json={"iamToken": "super_token"}, - status=200, - ) - mock_jwt.return_value = "zzzz" - responses.post( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", match=[ matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), matchers.query_param_matcher({"project": "my_folder_id"}) ], - json={"id": "query1"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", - json={"status": "COMPLETED"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", - json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]}, - status=200, + status=204, ) - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0", - json={ + assert self.hook.get_query_status(query_id) == "COMPLETED" + assert self.hook.get_query(query_id) == {"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]} + self.hook.stop_query(query_id) + + @responses.activate() + @mock.patch("jwt.encode") + def test_integral_results(self, mock_jwt): + # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L336 + self.setup_mocks_for_query_execution(mock_jwt, + { "rows":[[100,-100,200,200,10000000000,-20000000000,"18014398509481984","-18014398509481984",123.5,-789.125,"inf",True,False,"aGVsbG8=","hello","1.23","he\"llo_again","Я Привет",1,2,3,4]], "columns":[{"name":"column0","type":"Int32"},{"name":"column1","type":"Int32"},{"name":"column2","type":"Int64"},{"name":"column3","type":"Uint64"},{"name":"column4","type":"Uint64"},{"name":"column5","type":"Int64"},{"name":"column6","type":"Int64"},{"name":"column7","type":"Int64"},{"name":"column8","type":"Float"},{"name":"column9","type":"Double"},{"name":"column10","type":"Double"},{"name":"column11","type":"Bool"},{"name":"column12","type":"Bool"},{"name":"column13","type":"String"},{"name":"column14","type":"Utf8"},{"name":"column15","type":"Decimal(6,3)"},{"name":"column16","type":"Utf8"},{"name":"column17","type":"Utf8"},{"name":"column18","type":"Int8"},{"name":"column19","type":"Int16"},{"name":"column20","type":"Uint8"},{"name":"column21","type":"Uint16"}] - }, - status=200, + } ) self._init_hook() @@ -146,4 +158,57 @@ def test_integral_results(self, mock_jwt): ] ], "columns":[{"name":"column0","type":"Int32"},{"name":"column1","type":"Int32"},{"name":"column2","type":"Int64"},{"name":"column3","type":"Uint64"},{"name":"column4","type":"Uint64"},{"name":"column5","type":"Int64"},{"name":"column6","type":"Int64"},{"name":"column7","type":"Int64"},{"name":"column8","type":"Float"},{"name":"column9","type":"Double"},{"name":"column10","type":"Double"},{"name":"column11","type":"Bool"},{"name":"column12","type":"Bool"},{"name":"column13","type":"String"},{"name":"column14","type":"Utf8"},{"name":"column15","type":"Decimal(6,3)"},{"name":"column16","type":"Utf8"},{"name":"column17","type":"Utf8"},{"name":"column18","type":"Int8"},{"name":"column19","type":"Int16"},{"name":"column20","type":"Uint8"},{"name":"column21","type":"Uint16"}] - } \ No newline at end of file + } + + @responses.activate() + @mock.patch("jwt.encode") + def test_complex_results(self, mock_jwt): + # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L445 + self.setup_mocks_for_query_execution(mock_jwt, + { + "rows": [[[], [1, 2], [], [["YWJj", 1]], [["xyz", 1]], None, "PT15M", "2019-09-16", "2019-09-16T10:46:05Z", "2019-09-16T11:27:44.345849Z", "2019-09-16,Europe/Moscow", "2019-09-16T14:32:40,Europe/Moscow", "2019-09-16T14:32:55.874913,Europe/Moscow", ["One", 12], [1, "eHl6"], ["a", 1], ["monday", None], 1, {}, {"a": 1, "b": "xyz"}, None, None, [[[1, [[177]]]]], [[[1, []]]], [[[1, []]]], ["Foo", None], ["Bar", None], [], [1, "cHJpdmV0", "2019-09-16"]]], + "columns": [{"name": "column0", "type": "EmptyList"}, {"name": "column1", "type": "List"}, {"name": "column2", "type": "EmptyDict"}, {"name": "column3", "type": "Dict"}, {"name": "column4", "type": "Dict"}, {"name": "column5", "type": "Uuid"}, {"name": "column6", "type": "Interval"}, {"name": "column7", "type": "Date"}, {"name": "column8", "type": "Datetime"}, {"name": "column9", "type": "Timestamp"}, {"name": "column10", "type": "TzDate"}, {"name": "column11", "type": "TzDatetime"}, {"name": "column12", "type": "TzTimestamp"}, {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, {"name": "column14", "type": "Variant"}, {"name": "column15", "type": "Variant<'a':Int32>"}, {"name": "column16", "type": "Enum<'monday'>"}, {"name": "column17", "type": "Tagged"}, {"name": "column18", "type": "Struct<>"}, {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, {"name": "column20", "type": "Void"}, {"name": "column21", "type": "Null"}, {"name": "column22", "type": "Optional?>"}, {"name": "column23", "type": "Optional?>"}, {"name": "column24", "type": "Optional?>"}, {"name": "column25", "type": "Enum<'Bar','Foo'>"}, {"name": "column26", "type": "Enum<'Bar','Foo'>"}, {"name": "column27", "type": "Tuple<>"}, {"name": "column28", "type": "Tuple"}] + } + ) + + self._init_hook() + query_id = self.hook.create_query(query_text="complex_query1", name="my query", description="my desc") + assert query_id == "query1" + + results = self.hook.wait_results(query_id) + assert results == { + "rows": [ + [ + [], + [1, 2], + {}, + {"abc": 1}, + {"xyz": 1}, + None, # seems like http api doesn't support uuid values + "PT15M", + datetime(2019, 9, 16, 0, 0), + datetime(2019, 9, 16, 10, 46, 5, tzinfo=tzutc()), + datetime(2019, 9, 16, 11, 27, 44, 345849, tzinfo=tzutc()), + "2019-09-16,Europe/Moscow", + "2019-09-16T14:32:40,Europe/Moscow", + "2019-09-16T14:32:55.874913,Europe/Moscow", + 12, + "xyz", + 1, + "monday", + 1, + {}, + {"a": 1, "b": "xyz"}, + None, + None, + 177, + None, + None, + "Foo", + "Bar", + [], + (1, "privet", datetime(2019, 9, 16, 0, 0)) + ] + ], + "columns": [{"name": "column0", "type": "EmptyList"}, {"name": "column1", "type": "List"}, {"name": "column2", "type": "EmptyDict"}, {"name": "column3", "type": "Dict"}, {"name": "column4", "type": "Dict"}, {"name": "column5", "type": "Uuid"}, {"name": "column6", "type": "Interval"}, {"name": "column7", "type": "Date"}, {"name": "column8", "type": "Datetime"}, {"name": "column9", "type": "Timestamp"}, {"name": "column10", "type": "TzDate"}, {"name": "column11", "type": "TzDatetime"}, {"name": "column12", "type": "TzTimestamp"}, {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, {"name": "column14", "type": "Variant"}, {"name": "column15", "type": "Variant<'a':Int32>"}, {"name": "column16", "type": "Enum<'monday'>"}, {"name": "column17", "type": "Tagged"}, {"name": "column18", "type": "Struct<>"}, {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, {"name": "column20", "type": "Void"}, {"name": "column21", "type": "Null"}, {"name": "column22", "type": "Optional?>"}, {"name": "column23", "type": "Optional?>"}, {"name": "column24", "type": "Optional?>"}, {"name": "column25", "type": "Enum<'Bar','Foo'>"}, {"name": "column26", "type": "Enum<'Bar','Foo'>"}, {"name": "column27", "type": "Tuple<>"}, {"name": "column28", "type": "Tuple"}] + } From 90b891cc77faff7b9c1acba2caae5c833cde9069 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Wed, 14 Feb 2024 16:57:41 +0300 Subject: [PATCH 10/34] add tests for YQ operator --- airflow/providers/yandex/hooks/http_client.py | 28 ++++- .../providers/yandex/hooks/query_results.py | 20 ++++ .../providers/yandex/hooks/yandexcloud_yq.py | 4 +- .../yandex/hooks/test_yandexcloud_yq.py | 1 - .../yandex/operators/test_yandexcloud_yq.py | 109 ++++++++++++++++++ 5 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 tests/providers/yandex/operators/test_yandexcloud_yq.py diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index efc69e0a3b766..553d23d62c271 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -1,3 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is a copy of https://github.com/ydb-platform/ydb/tree/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/core/fq/libs/http_api_client +# It is highly recommended to modify original file first in YDB project and merge it here afterwards + from __future__ import annotations import logging @@ -209,7 +229,7 @@ def wait_query_to_succeed(self, query_id, execution_timeout=None, stop_on_timeou query = self.get_query(query_id) if status != "COMPLETED": issues = query["issues"] - raise RuntimeError(f"Query {query_id} failed", issues=issues) + raise RuntimeError(f"Query {query_id} failed with issues={issues}") return len(query["result_sets"]) @@ -289,3 +309,9 @@ def get_openapi_spec(self) -> str: def compose_query_web_link(self, query_id) -> str: return self._compose_web_url(f"/folders/{self.config.project}/ide/queries/{query_id}") + + @staticmethod + def result_set_to_dataframe(data): + import pandas as pd + column_names = [column["name"] for column in data["columns"]] + return pd.DataFrame(data["rows"], columns=column_names) diff --git a/airflow/providers/yandex/hooks/query_results.py b/airflow/providers/yandex/hooks/query_results.py index e40afb39324fb..7909071cce0dc 100644 --- a/airflow/providers/yandex/hooks/query_results.py +++ b/airflow/providers/yandex/hooks/query_results.py @@ -1,3 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is a copy of https://github.com/ydb-platform/ydb/tree/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/core/fq/libs/http_api_client +# It is highly recommended to modify original file first in YDB project and merge it here afterwards + from __future__ import annotations from typing import Any, Optional import base64 diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 3da21ef01e9a3..ed707ac04d7d3 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -58,7 +58,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) config = YQHttpClientConfig( - token=self.get_iam_token(), + token=self._get_iam_token(), project=self.default_folder_id, user_agent=provider_user_agent() ) @@ -96,7 +96,7 @@ def wait_results(self, query_id: str) -> Any: return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) - def get_iam_token(self) -> str: + def _get_iam_token(self) -> str: if "token" in self.credentials: return self.credentials["token"] if "service_account_key" in self.credentials: diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index 9475184705655..c7ea48500a925 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import json import responses from datetime import datetime from dateutil.tz import tzutc diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py new file mode 100644 index 0000000000000..69ca246a27c3c --- /dev/null +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timedelta +import pytest +import responses +from responses import matchers +import re +from unittest.mock import MagicMock, call, patch + +from airflow.models.dag import DAG +from airflow.models import Connection +from airflow.providers.yandex.operators.yandexcloud_yq import YQExecuteQueryOperator + +OAUTH_TOKEN = "my_oauth_token" +FOLDER_ID = "my_folder_id" + + +class TestYQExecuteQueryOperator: + def setup_method(self): + dag_id = "test_dag" + self.dag = DAG( + dag_id, + default_args={ + "owner": "airflow", + "start_date": datetime.today(), + "end_date": datetime.today() + timedelta(days=1), + }, + schedule="@once", + ) + + @responses.activate() + @patch("airflow.hooks.base.BaseHook.get_connection") + def test_create_cluster(self, mock_get_connection): + mock_get_connection.return_value = Connection(extra={"oauth": OAUTH_TOKEN}) + operator = YQExecuteQueryOperator( + task_id="simple_sql", + sql="select 987", + folder_id="my_folder_id" + ) + context = {"ti": MagicMock()} + + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher({"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"}), + matchers.query_param_matcher({"project": FOLDER_ID}) + ], + json={"id": "query1"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "COMPLETED"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", + json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0", + json={"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}, + status=200, + ) + + results = operator.execute(context) + assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} + + context["ti"].xcom_push.assert_has_calls( + [ + call(key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1"), + ] + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "ERROR"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", + json={"id": "query1", "issues": ["some error"]}, + status=200, + ) + + with pytest.raises(RuntimeError, match=re.escape("""Query query1 failed with issues=['some error']""")): + operator.execute(context) + \ No newline at end of file From 88156b2f8bf3491c682251fe3e9ce0320dc2827b Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Wed, 14 Feb 2024 18:41:02 +0300 Subject: [PATCH 11/34] fix test name --- tests/providers/yandex/operators/test_yandexcloud_yq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py index 69ca246a27c3c..b984d49d72351 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -46,7 +46,7 @@ def setup_method(self): @responses.activate() @patch("airflow.hooks.base.BaseHook.get_connection") - def test_create_cluster(self, mock_get_connection): + def test_execute_query(self, mock_get_connection): mock_get_connection.return_value = Connection(extra={"oauth": OAUTH_TOKEN}) operator = YQExecuteQueryOperator( task_id="simple_sql", From 04e8c92e16633871cf612faaa153de70bbd3b014 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Wed, 14 Feb 2024 19:38:12 +0300 Subject: [PATCH 12/34] linting --- airflow/example_dags/airflow_dag_yq.py | 221 ------------------ .../providers/yandex/hooks/query_results.py | 84 +++---- 2 files changed, 43 insertions(+), 262 deletions(-) delete mode 100644 airflow/example_dags/airflow_dag_yq.py diff --git a/airflow/example_dags/airflow_dag_yq.py b/airflow/example_dags/airflow_dag_yq.py deleted file mode 100644 index 34f2e577dc1a3..0000000000000 --- a/airflow/example_dags/airflow_dag_yq.py +++ /dev/null @@ -1,221 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Example DAG demonstrating the usage of the BashOperator.""" -from __future__ import annotations - -import datetime - -import pendulum -import dateutil - -from airflow.models.dag import DAG -from airflow.operators.bash import BashOperator -from airflow.operators.empty import EmptyOperator -from airflow.providers.yandex.operators.yandexcloud_yq import YQExecuteQueryOperator -from airflow.operators.python import PythonOperator -from airflow.decorators import task -import base64 -# import airflow.providers.yandex.operators.yandexcloud_dataproc - -with DAG( - dag_id="yq_operator", - # schedule="@daily", - schedule_interval='30 2 * * *', - start_date=pendulum.datetime(2023, 1, 16, 19, 15, tz="UTC"), - catchup=False, - dagrun_timeout=datetime.timedelta(minutes=60) -) as dag: - folder_id = "b1gaud5b392mmmeolb0k" - run_this_last = EmptyOperator( - task_id="run_this_last", - ) - - # # [START howto_operator_bash] - # run_this = BashOperator( - # task_id="run_after_loop", - # bash_command="echo 1", - # ) - # # [END howto_operator_bash] - - # run_this >> run_this_last - - # # [END howto_operator_bash_template] - # also_run_this >> run_this_last - - # @task - # def yq_read_data(): - # operator = YQExecuteQueryOperator(task_id="samplequery2", sql="select 22 as d, 33 as t") - # return operator.execute() - - # @task - # def ydb_write_data(yq_data): - # ydb_bulk_insert = YDBBulkUpsertOperator(task_id="bulk_insert", - # endpoint="grpcs://ydb.serverless.yandexcloud.net:2135", - # database="/ru-central1/b1g8skpblkos03malf3s/etndta0jk4us20e557i7", - # table="my_table", - # column_types={"id": ydb.PrimitiveType.Uint64, - # "name": ydb.PrimitiveType.Utf8}, - # values=[{"id":1,"name":"v"}]) - # return ydb_bulk_insert.execute() - - # data = yq_read_data() - # ydb_write_data(data) - - def base64ToString(b): - return base64.b64decode(b).decode('utf-8') - - def get_column_index(columns, name): - for index, column in enumerate(columns): - if column["name"] == name: - return index - - def get_col_by_name(row, columns, name): - index = get_column_index(columns, name) - value = row[index] - # print(value) - return value - - def process_query_count_result(**kwargs): - ti = kwargs['ti'] - result = ti.xcom_pull(task_ids='get_queries_count') - # print(result) - - print(f"Incoming rows={result['rows']}") - - def process_result(**kwargs): - ti = kwargs['ti'] - result = ti.xcom_pull(task_ids='samplequery2') - - print(f"Incoming rows={result['rows']}") - - - # ydb_bulk_insert >> run_this_last - - query = """ -$parse_ingress_bytes = ($m) -> { - $t = "IngressBytes: ["; - $start = Find($m, $t); - $end = Find($m, "]", $start); - return Cast(Substring($m, $start+LEN($t), $end-$start-LEN($t)) as Uint64); - }; - -$parse_folder_id = ($m) -> { - $t = "scope: [yandexcloud://"; - $start = Find($m, $t); - $end = Find($m, "]", $start); - return Substring($m, $start+LEN($t), $end-$start-LEN($t)); - }; - -$parse_status = ($m) -> { - $t = ", status:"; - $start = Find($m, $t); - $end = LEN($m); - return String::Strip(Substring($m, $start+LEN($t), $end-$start-LEN($t))); - }; - -select * from ( -select `@timestamp` as ts, $parse_ingress_bytes(message) as ingress_bytes, $parse_folder_id(message) as folder_id, $parse_status(message) as status from (select - `yq_prod_logs_cold_projected`.* -FROM - `yq_prod_logs_cold_projected`) -where component="YQ_AUDIT" and message like "FinalStatus%" -and message like "%IngressBytes%" -and `date` between Date("{{ data_interval_start | ds }}") and Date("{{ data_interval_end | ds }}") -) -where COALESCE(folder_id,"") != "" -limit 1000; - -""" - - yq_operator2 = YQExecuteQueryOperator(task_id="samplequery2", sql=query, connection_id="yandexcloud_default", folder_id=folder_id) - yq_operator2 >> run_this_last - - query_count_queries = """ -$parse_folder_id = ($m) -> { - $t = "scope: [yandexcloud://"; - $start = Find($m, $t); - $end = Find($m, "]", $start); - return Substring($m, $start+LEN($t), $end-$start-LEN($t)); - }; - -$parse_status = ($m) -> { - $t = ", status:"; - $start = Find($m, $t); - $end = LEN($m); - return String::Strip(Substring($m, $start+LEN($t), $end-$start-LEN($t))); - }; - -$parse_query_id = ($m) -> { - $t = "query id: ["; - $start = Find($m, $t); - $end = Find($m, "]", $start); - return Substring($m, $start+LEN($t), $end-$start-LEN($t)); - }; - -select * from ( -select `@timestamp` as ts, $parse_folder_id(message) as folder_id, $parse_status(message) as status, $parse_query_id(message) as query_id from (select - `yq_prod_logs_cold_projected`.* -FROM - `yq_prod_logs_cold_projected`) -where component="YQ_AUDIT" and message like "FinalStatus%" -and `date` between Date("{{ data_interval_start | ds }}") and Date("{{ data_interval_end | ds }}") -) -where COALESCE(folder_id,"") != "" -limit 1000; - -""" - - process_query_count_task = PythonOperator( task_id='process_query_count_result', - python_callable=process_query_count_result, - provide_context=True) - - yq_operator_queries_count = YQExecuteQueryOperator(task_id="get_queries_count", sql=query_count_queries, connection_id="yandexcloud_default", folder_id=folder_id) - yq_operator_queries_count >> process_query_count_task - - # yq_operator3 = YQExecuteQueryOperator(task_id="samplequery3", sql="select 33 as d, 44 as t") - # yq_operator3 >> run_this_last - - # yq_operator4 = YQExecuteQueryOperator(task_id="samplequery4", sql="select 33 as d, 44 as t") - # # yq_operator4 >> ydb_bulk_insert - - - # yq_operator = YQExecuteQueryOperator(task_id="samplequery", sql="select 1") - # yq_operator >> yq_operator2 - - - process_result_task = PythonOperator( - task_id='process_result', - python_callable=process_result, - provide_context=True) - - yq_operator2 >> process_result_task - - - - - # # [START howto_operator_bash_skip] - # this_will_skip = BashOperator( - # task_id="this_will_skip", - # bash_command='echo "hello world"; exit 99;', - # dag=dag, - # ) - # # [END howto_operator_bash_skip] - # this_will_skip >> run_this_last - -if __name__ == "__main__": - dag.test() diff --git a/airflow/providers/yandex/hooks/query_results.py b/airflow/providers/yandex/hooks/query_results.py index 7909071cce0dc..689e337f82f3b 100644 --- a/airflow/providers/yandex/hooks/query_results.py +++ b/airflow/providers/yandex/hooks/query_results.py @@ -40,13 +40,13 @@ def _convert_from_float(value: float | str) -> float: return float(value) @staticmethod - def _convert_from_pgfloat(value: str|None) -> float: + def _convert_from_pgfloat(value: str | None) -> Optional[float]: if value is None: return None return float(value) @staticmethod - def _convert_from_pgint(value: str | None) -> int: + def _convert_from_pgint(value: str | None) -> Optional[int]: if value is None: return None return int(value) @@ -54,9 +54,9 @@ def _convert_from_pgint(value: str | None) -> int: @staticmethod def _convert_from_decimal(value: str) -> Decimal: return Decimal(value) - + @staticmethod - def _convert_from_pgnumeric(value: str | None) -> Decimal: + def _convert_from_pgnumeric(value: str | None) -> Optional[Decimal]: if value is None: return None return Decimal(value) @@ -65,7 +65,7 @@ def _convert_from_pgnumeric(value: str | None) -> Decimal: def _convert_from_base64(value: str) -> str | bytes: b = base64.b64decode(value) try: - return b.decode('utf-8') + return b.decode("utf-8") except: return b @@ -75,7 +75,7 @@ def _convert_from_datetime(value: str) -> datetime: return dateutil.parser.isoparse(value) @staticmethod - def _convert_from_pgdatetime(value: str | None) -> datetime: + def _convert_from_pgdatetime(value: str | None) -> Optional[datetime]: if value is None: return None return dateutil.parser.isoparse(value) @@ -83,47 +83,47 @@ def _convert_from_pgdatetime(value: str | None) -> datetime: @staticmethod def _convert_from_enum(value: list) -> str: return str(value[0]) - + @staticmethod - def _extract_from_optional(type: str) -> str: + def _extract_from_optional(type_name: str) -> str: # Uint16? -> Uint16 - if type.endswith("?"): - return type[0:-1] + if type_name.endswith("?"): + return type_name[0:-1] # Optional -> Uint16 - return type[len("Optional<"):-1] - + return type_name[len("Optional<"):-1] + @staticmethod - def _extract_from_set(type: str) -> str: + def _extract_from_set(type_name: str) -> str: # Set -> Uint16 - return type[len("Set<"):-1] + return type_name[len("Set<"):-1] @staticmethod - def _extract_from_list(type: str) -> str: + def _extract_from_list(type_name: str) -> str: # List -> Uint16 - return type[len("List<"):-1] - + return type_name[len("List<"):-1] + @staticmethod def _split_type_list(type_list: str) -> list[str]: # naive implementation # todo fix it return type_list.split(",") - + @staticmethod - def _extract_from_tuple(type: str) -> str: + def _extract_from_tuple(type_name: str) -> list[str]: # Tuple -> [Uint16, String, Double] - return YQResults._split_type_list(type[len("Tuple<"):-1]) - + return YQResults._split_type_list(type_name[len("Tuple<"):-1]) + @staticmethod - def _extract_from_dict(type: str) -> (str, str): + def _extract_from_dict(type_name: str) -> tuple[str, str]: # Dict -> (Uint16, String) - [key_type, value_type] = YQResults._split_type_list(type[len("Dict<"):-1]) + [key_type, value_type] = YQResults._split_type_list(type_name[len("Dict<"):-1]) return key_type, value_type - + @staticmethod - def _extract_from_variant_over_struct(type: str) -> (str, str): + def _extract_from_variant_over_struct(type_name: str) -> tuple[str, str]: # Variant<'One':Int32,'Two':String> -> {One: Int32, Two: String} - types_with_names = YQResults._split_type_list(type[len("Variant<"):-1]) + types_with_names = YQResults._split_type_list(type_name[len("Variant<"):-1]) result = {} for t in types_with_names: [n, t] = t.split(":") @@ -131,11 +131,11 @@ def _extract_from_variant_over_struct(type: str) -> (str, str): n = n[1:-1] result[n] = t return result - + @staticmethod - def _extract_from_variant_over_tuple(type: str) -> (str, str): + def _extract_from_variant_over_tuple(type_name: str) -> tuple[str, str]: # Variant -> [Int32, String] - return YQResults._split_type_list(type[len("Variant<"):-1]) + return YQResults._split_type_list(type_name[len("Variant<"):-1]) @staticmethod def _convert_from_optional(value: list[Any]) -> Optional[Any]: @@ -154,7 +154,7 @@ def _convert_from_optional(value: list[Any]) -> Optional[Any]: @staticmethod def id(v): return v - + @staticmethod def _get_converter(column_type: str) -> Any: """Returns converter based on column type""" @@ -178,10 +178,10 @@ def _get_converter(column_type: str) -> Any: if column_type.startswith("Enum<"): return YQResults._convert_from_enum - + if column_type in ["Date", "Datetime", "Timestamp"]: return YQResults._convert_from_datetime - + # containers if column_type.startswith("Optional<") or column_type.endswith("?"): # If type is Optional than get base type @@ -197,7 +197,7 @@ def convert(x): return inner_converter(inner_value) return convert - + if column_type.startswith("Set<"): inner_converter = YQResults._get_converter(YQResults._extract_from_set(column_type)) @@ -205,7 +205,7 @@ def convert(x): return {inner_converter(v) for v in x} return convert - + if column_type.startswith("List<"): inner_converter = YQResults._get_converter(YQResults._extract_from_list(column_type)) @@ -217,9 +217,10 @@ def convert(x): if column_type.startswith("Tuple<"): inner_types = YQResults._extract_from_tuple(column_type) inner_converters = [YQResults._get_converter(t) for t in inner_types] - + def convert(x): - assert len(x) == len(inner_converters), f"Wrong lenght for tuple value: {len(x)} != {len(inner_converters)}" + assert len(x) == len( + inner_converters), f"Wrong lenght for tuple value: {len(x)} != {len(inner_converters)}" return tuple([c(v) for (c, v) in zip(inner_converters, x)]) return convert @@ -228,7 +229,7 @@ def convert(x): if column_type.startswith("Variant<'"): inner_types = YQResults._extract_from_variant_over_struct(column_type) inner_converters = {k: YQResults._get_converter(t) for k, t in inner_types.items()} - + def convert(x): return inner_converters[x[0]](x[1]) @@ -238,7 +239,7 @@ def convert(x): if column_type.startswith("Variant<"): inner_types = YQResults._extract_from_variant_over_tuple(column_type) inner_converters = [YQResults._get_converter(t) for t in inner_types] - + def convert(x): return inner_converters[x[0]](x[1]) @@ -249,7 +250,7 @@ def convert(x): return {} return convert - + if column_type.startswith("Dict<"): key_type, value_type = YQResults._extract_from_dict(column_type) key_converter = YQResults._get_converter(key_type) @@ -269,10 +270,10 @@ def convert(x): if column_type == "pgnumeric": return YQResults._convert_from_pgnumeric - + if column_type in ["pgdate", "pgtimestamp"]: return YQResults._convert_from_pgdatetime - + if column_type.startswith("pg"): return YQResults.id @@ -318,3 +319,4 @@ def to_dataframe(self): columns = [column["name"] for column in result_set["columns"]] import pandas return pandas.DataFrame(result_set["rows"], columns=columns) + \ No newline at end of file From be6bbc78d3965a3f7ab06cb56acfd56f80064d6b Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Wed, 14 Feb 2024 20:17:04 +0300 Subject: [PATCH 13/34] restyling --- airflow/providers/yandex/hooks/http_client.py | 12 ++++++------ airflow/providers/yandex/hooks/query_results.py | 4 ++-- airflow/providers/yandex/hooks/yandexcloud_yq.py | 4 ++-- tests/providers/yandex/hooks/test_yandexcloud_yq.py | 6 +++--- .../yandex/operators/test_yandexcloud_yq.py | 1 - 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index 553d23d62c271..3159eafd1c784 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -69,7 +69,7 @@ def __init__(self, class YQHttpClientException(Exception): - def __init__(self, message: str, status: str, msg: str, details: Any) -> None: + def __init__(self, message: str, status: str = None, msg: str = None, details: Any = None) -> None: super().__init__(message) self.status = status self.msg = msg @@ -133,7 +133,7 @@ def _validate_http_error(self, response, expected_code=200) -> None: def create_query(self, query_text=None, - type=None, + query_type=None, name=None, description=None, idempotency_key=None, @@ -143,8 +143,8 @@ def create_query(self, if query_text is not None: body["text"] = query_text - if type is not None: - body["type"] = type + if query_type is not None: + body["type"] = query_type if name is not None: body["name"] = name @@ -284,11 +284,11 @@ def get_query_result_set(self, query_id: str, result_set_index: int, raw_format: result = {"rows": rows, "columns": columns} if raw_format: return result - + return YQResults(result).results def get_query_all_result_sets(self, query_id: str, result_set_count: int, raw_format: bool = False) -> Any: - result = list() + result = [] for i in range(0, result_set_count): r = self.get_query_result_set( query_id, diff --git a/airflow/providers/yandex/hooks/query_results.py b/airflow/providers/yandex/hooks/query_results.py index 689e337f82f3b..5644f759e7d8b 100644 --- a/airflow/providers/yandex/hooks/query_results.py +++ b/airflow/providers/yandex/hooks/query_results.py @@ -66,7 +66,7 @@ def _convert_from_base64(value: str) -> str | bytes: b = base64.b64decode(value) try: return b.decode("utf-8") - except: + except UnicodeDecodeError: return b @staticmethod @@ -106,7 +106,7 @@ def _extract_from_list(type_name: str) -> str: @staticmethod def _split_type_list(type_list: str) -> list[str]: # naive implementation - # todo fix it + # fixme: fix it return type_list.split(",") @staticmethod diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index ed707ac04d7d3..9c879639699a4 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -69,11 +69,11 @@ def close(self): self.client.close() def create_query(self, query_text: str|None, name: str|None=None, description: str | None = None, query_type: QueryType = QueryType.ANALYTICS) -> str: - type = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" + t = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" return self.client.create_query( name=name, - type=type, + query_type=t, query_text=query_text, description=description ) diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index c7ea48500a925..004af55333478 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -45,7 +45,7 @@ def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): status=200, ) mock_jwt.return_value = "zzzz" - + responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", match=[ @@ -111,7 +111,7 @@ def test_select_results(self, mock_jwt): results = self.hook.wait_results(query_id) assert results == {"rows": [[777]], "columns": [ {"name": "column0", "type": "Int32"}]} - + responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", match=[ @@ -124,7 +124,7 @@ def test_select_results(self, mock_jwt): assert self.hook.get_query_status(query_id) == "COMPLETED" assert self.hook.get_query(query_id) == {"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]} self.hook.stop_query(query_id) - + @responses.activate() @mock.patch("jwt.encode") def test_integral_results(self, mock_jwt): diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py index b984d49d72351..a2a1a6c4ecc2d 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -106,4 +106,3 @@ def test_execute_query(self, mock_get_connection): with pytest.raises(RuntimeError, match=re.escape("""Query query1 failed with issues=['some error']""")): operator.execute(context) - \ No newline at end of file From 0708e52c2907ca44304c28681bff45f79982b793 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 15 Feb 2024 15:33:52 +0300 Subject: [PATCH 14/34] improve tests, fix close(), add link to YQ service --- airflow/providers/yandex/hooks/http_client.py | 5 ++++- .../providers/yandex/hooks/yandexcloud_yq.py | 15 ++++---------- .../yandex/operators/yandexcloud_yq.py | 7 +------ airflow/providers/yandex/provider.yaml | 2 +- .../yandex/hooks/test_yandexcloud_yq.py | 20 ++++++++++--------- .../yandex/operators/test_yandexcloud_yq.py | 3 ++- 6 files changed, 23 insertions(+), 29 deletions(-) diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index 3159eafd1c784..1cb8ae64f8828 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -87,6 +87,9 @@ def __enter__(self): def __exit__(self, *args): self.session.close() + def close(self): + self.session.close() + def _build_headers(self, idempotency_key=None, request_id=None) -> dict[str, str]: headers = { "Authorization": f"{self.config.token_prefix}{self.config.token}" @@ -116,7 +119,7 @@ def _compose_web_url(self, path: str) -> str: return self.config.web_base_url + path def _validate_http_error(self, response, expected_code=200) -> None: - logging.info(f"Response: {response.status_code}, {response.text}") + logging.debug("Response: %s, %s", response.status_code, response.text) if response.status_code != expected_code: if response.headers.get("Content-Type", "").startswith("application/json"): body = response.json() diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 9c879639699a4..18140e889b15d 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -17,8 +17,7 @@ from __future__ import annotations from requests.packages.urllib3.util.retry import Retry -from datetime import timedelta, datetime -from enum import Enum +from datetime import timedelta import logging import requests import time @@ -45,9 +44,6 @@ from .http_client import YQHttpClientConfig, YQHttpClient -class QueryType(Enum): - ANALYTICS = 1 - STREAMING = 2 class YQHook(YandexCloudBaseHook): """ @@ -68,12 +64,9 @@ def __init__(self, *args, **kwargs) -> None: def close(self): self.client.close() - def create_query(self, query_text: str|None, name: str|None=None, description: str | None = None, query_type: QueryType = QueryType.ANALYTICS) -> str: - t = "ANALYTICS" if query_type == QueryType.ANALYTICS else "STREAMING" - + def create_query(self, query_text: str|None, name: str|None=None, description: str | None = None) -> str: return self.client.create_query( name=name, - query_type=t, query_text=query_text, description=description ) @@ -87,10 +80,10 @@ def get_query(self, query_id: str) -> Any: def get_query_status(self, query_id: str) -> str: return self.client.get_query_status(query_id) - def wait_results(self, query_id: str) -> Any: + def wait_results(self, query_id: str, execution_timeout: timedelta = timedelta(minutes=30)) -> Any: result_set_count = self.client.wait_query_to_succeed( query_id, - execution_timeout=timedelta(minutes=30), + execution_timeout=execution_timeout, stop_on_timeout=True ) diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index ad68d79a50eaa..c31bf84c6575a 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -16,15 +16,13 @@ # under the License. from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING, Sequence, Any -from datetime import timedelta from airflow.configuration import conf from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.models import BaseOperator, BaseOperatorLink, XCom from airflow.models.taskinstancekey import TaskInstanceKey -from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook, QueryType +from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook if TYPE_CHECKING: from airflow.utils.context import Context @@ -54,7 +52,6 @@ class YQExecuteQueryOperator(SQLExecuteQueryOperator): def __init__( self, *, - type: QueryType = QueryType.ANALYTICS, name: str | None = None, description: str | None = None, folder_id: str | None = None, @@ -67,7 +64,6 @@ def __init__( super().__init__(**kwargs) self.name = name self.description = description - self.type = type self.deferrable = deferrable self.folder_id = folder_id self.connection_id = connection_id @@ -86,7 +82,6 @@ def execute(self, context: Context) -> Any: ) self.query_id = self.hook.create_query( - query_type=self.type, query_text=self.sql, name=self.name, description=self.description diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index a433b3c72d443..7a91ed9cd259b 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -63,7 +63,7 @@ integrations: tags: [service] - integration-name: Yandex.Cloud YQ - external-doc-url: https://cloud.yandex.com/dataproc + external-doc-url: https://cloud.yandex.com/en/services/query how-to-guide: - /docs/apache-airflow-providers-yandex/operators.rst logo: /integration-logos/yandex/Yandex-Cloud.png diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index 004af55333478..b84099dc776f7 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -50,7 +50,8 @@ def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", match=[ matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), - matchers.query_param_matcher({"project": "my_folder_id"}) + matchers.query_param_matcher({"project": "my_folder_id"}), + matchers.json_params_matcher({"description": "my desc", "name": "my query", "text": "select 777"}) ], json={"id": "query1"}, status=200, @@ -80,6 +81,11 @@ def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): status=200, ) + def _create_test_query(self): + query_id = self.hook.create_query(query_text="select 777", name="my query", description="my desc") + assert query_id == "query1" + return query_id + @responses.activate() def test_oauth_token_usage(self): responses.post( @@ -94,8 +100,7 @@ def test_oauth_token_usage(self): self.connection = Connection(extra={"oauth": OAUTH_TOKEN}) self._init_hook() - query_id = self.hook.create_query(query_text="select 777") - assert query_id == "query1" + self._create_test_query() @responses.activate() @mock.patch("jwt.encode") @@ -103,8 +108,7 @@ def test_select_results(self, mock_jwt): self.setup_mocks_for_query_execution(mock_jwt, {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}) self._init_hook() - query_id = self.hook.create_query(query_text="select 777", name="my query", description="my desc") - assert query_id == "query1" + query_id = self._create_test_query() assert self.hook.compose_query_web_link(query_id) == "https://yq.cloud.yandex.ru/folders/my_folder_id/ide/queries/query1" @@ -137,8 +141,7 @@ def test_integral_results(self, mock_jwt): ) self._init_hook() - query_id = self.hook.create_query(query_text="complex_query1", name="my query", description="my desc") - assert query_id == "query1" + query_id = self._create_test_query() results = self.hook.wait_results(query_id) assert results == { @@ -171,8 +174,7 @@ def test_complex_results(self, mock_jwt): ) self._init_hook() - query_id = self.hook.create_query(query_text="complex_query1", name="my query", description="my desc") - assert query_id == "query1" + query_id = self._create_test_query() results = self.hook.wait_results(query_id) assert results == { diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py index a2a1a6c4ecc2d..f996bd3afff16 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -59,7 +59,8 @@ def test_execute_query(self, mock_get_connection): "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", match=[ matchers.header_matcher({"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"}), - matchers.query_param_matcher({"project": FOLDER_ID}) + matchers.query_param_matcher({"project": FOLDER_ID}), + matchers.json_params_matcher({"text": "select 987"}) ], json={"id": "query1"}, status=200, From 0f7f968eccccb776bb1158247e6289bc2b4d8038 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 15 Feb 2024 15:41:47 +0300 Subject: [PATCH 15/34] trim spaces --- airflow/providers/yandex/hooks/query_results.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/yandex/hooks/query_results.py b/airflow/providers/yandex/hooks/query_results.py index 5644f759e7d8b..d61174b737296 100644 --- a/airflow/providers/yandex/hooks/query_results.py +++ b/airflow/providers/yandex/hooks/query_results.py @@ -319,4 +319,3 @@ def to_dataframe(self): columns = [column["name"] for column in result_set["columns"]] import pandas return pandas.DataFrame(result_set["rows"], columns=columns) - \ No newline at end of file From cf4c6efb8aa9660c6bd61366e1ff2c695557d4ab Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 15 Feb 2024 19:40:09 +0300 Subject: [PATCH 16/34] add docstrings, remove query description, move privates to bottom of the file --- .../providers/yandex/hooks/yandexcloud_yq.py | 59 ++++++++++++++----- .../yandex/operators/yandexcloud_yq.py | 11 ++-- .../yandex/hooks/test_yandexcloud_yq.py | 4 +- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 18140e889b15d..da2c4646297dd 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -62,32 +62,62 @@ def __init__(self, *args, **kwargs) -> None: self.client: YQHttpClient = YQHttpClient(config=config) def close(self): + """Release all resources""" self.client.close() - def create_query(self, query_text: str|None, name: str|None=None, description: str | None = None) -> str: + def create_query(self, query_text: str | None, name: str | None = None) -> str: + """Create and run query. + + :param query_text: SQL text. + :param name: name for the query + """ + return self.client.create_query( name=name, query_text=query_text, - description=description ) + def wait_results(self, query_id: str, execution_timeout: timedelta = timedelta(minutes=30)) -> Any: + """Wait for query complete and get results + + :param query_id: ID of query. + :param execution_timeout: how long to wait for the query to complete. + """ + result_set_count = self.client.wait_query_to_succeed( + query_id, + execution_timeout=execution_timeout, + stop_on_timeout=True + ) + + return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) + def stop_query(self, query_id: str) -> None: + """Stop the query. + + :param query_id: ID of the query. + """ self.client.stop_query(query_id) def get_query(self, query_id: str) -> Any: + """Get query info. + + :param query_id: ID of the query. + """ return self.client.get_query(query_id) def get_query_status(self, query_id: str) -> str: + """Get status fo the query. + + :param query_id: ID of query. + """ return self.client.get_query_status(query_id) - def wait_results(self, query_id: str, execution_timeout: timedelta = timedelta(minutes=30)) -> Any: - result_set_count = self.client.wait_query_to_succeed( - query_id, - execution_timeout=execution_timeout, - stop_on_timeout=True - ) - - return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) + def compose_query_web_link(self, query_id: str): + """Compose web link to query in Yandex Query UI + + :param query_id: ID of query. + """ + return self.client.compose_query_web_link(query_id) def _get_iam_token(self) -> str: if "token" in self.credentials: @@ -96,12 +126,9 @@ def _get_iam_token(self) -> str: return YQHook._resolve_service_account_key(self.credentials["service_account_key"]) raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}") - def compose_query_web_link(self, query_id:str): - return self.client.compose_query_web_link(query_id) - @staticmethod - def _resolve_service_account_key(sa_info) -> str: - with YQHook.create_session() as session: + def _resolve_service_account_key(sa_info: dict) -> str: + with YQHook._create_session() as session: api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' now = int(time.time()) payload = { @@ -125,7 +152,7 @@ def _resolve_service_account_key(sa_info) -> str: return iam_response.json()["iamToken"] @staticmethod - def create_session() -> requests.Session: + def _create_session() -> requests.Session: session = requests.Session() session.verify = False session.timeout = 20 diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index c31bf84c6575a..efd7ba973a022 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -41,6 +41,10 @@ class YQExecuteQueryOperator(SQLExecuteQueryOperator): Executes sql code using Yandex Query service. :param sql: the SQL code to be executed as a single string + :param name: name of the query in YandexQuery + :param folder_id: cloud folder id where to create query + :param connection_id: Airflow connection ID to get parameters from + :param folder_id: cloud folder id where to create query """ operator_extra_links = (YQLink(),) @@ -53,7 +57,6 @@ def __init__( self, *, name: str | None = None, - description: str | None = None, folder_id: str | None = None, connection_id: str | None = None, public_ssh_key: str | None = None, @@ -63,7 +66,6 @@ def __init__( ) -> None: super().__init__(**kwargs) self.name = name - self.description = description self.deferrable = deferrable self.folder_id = folder_id self.connection_id = connection_id @@ -71,7 +73,7 @@ def __init__( self.service_account_id = service_account_id self.hook: YQHook | None = None - self.query_id: str | None + self.query_id: str | None = None def execute(self, context: Context) -> Any: self.hook = YQHook( @@ -83,8 +85,7 @@ def execute(self, context: Context) -> Any: self.query_id = self.hook.create_query( query_text=self.sql, - name=self.name, - description=self.description + name=self.name ) # pass to YQLink diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index b84099dc776f7..9956bfa849fd1 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -51,7 +51,7 @@ def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): match=[ matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), matchers.query_param_matcher({"project": "my_folder_id"}), - matchers.json_params_matcher({"description": "my desc", "name": "my query", "text": "select 777"}) + matchers.json_params_matcher({"name": "my query", "text": "select 777"}) ], json={"id": "query1"}, status=200, @@ -82,7 +82,7 @@ def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): ) def _create_test_query(self): - query_id = self.hook.create_query(query_text="select 777", name="my query", description="my desc") + query_id = self.hook.create_query(query_text="select 777", name="my query") assert query_id == "query1" return query_id From 773d996ca92067b6148b6a2b8e808eb3d8b3c366 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 15 Feb 2024 19:54:02 +0300 Subject: [PATCH 17/34] fix last newline --- airflow/providers/yandex/hooks/yandex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py index 8e508ccc6ca8c..251a47b7b8d93 100644 --- a/airflow/providers/yandex/hooks/yandex.py +++ b/airflow/providers/yandex/hooks/yandex.py @@ -157,4 +157,4 @@ def _get_endpoint(self) -> dict[str, str]: def _get_field(self, field_name: str, default: Any = None) -> Any: if not hasattr(self, "extras"): return default - return get_field_from_extras(self.extras, field_name, default) \ No newline at end of file + return get_field_from_extras(self.extras, field_name, default) From d23a317caeef2ebbc0692674da8d26180f1fc6da Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 15 Feb 2024 19:56:08 +0300 Subject: [PATCH 18/34] restyling --- .../providers/yandex/hooks/yandexcloud_yq.py | 51 ++++++------------- 1 file changed, 15 insertions(+), 36 deletions(-) diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index da2c4646297dd..02162ed74a162 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -29,6 +29,7 @@ # The only thing missing will be the response.body which is not logged. import http.client + http.client.HTTPConnection.debuglevel = 1 # You must initialize logging, otherwise you'll not see debug output. @@ -54,9 +55,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) config = YQHttpClientConfig( - token=self._get_iam_token(), - project=self.default_folder_id, - user_agent=provider_user_agent() + token=self._get_iam_token(), project=self.default_folder_id, user_agent=provider_user_agent() ) self.client: YQHttpClient = YQHttpClient(config=config) @@ -67,7 +66,7 @@ def close(self): def create_query(self, query_text: str | None, name: str | None = None) -> str: """Create and run query. - + :param query_text: SQL text. :param name: name for the query """ @@ -79,42 +78,40 @@ def create_query(self, query_text: str | None, name: str | None = None) -> str: def wait_results(self, query_id: str, execution_timeout: timedelta = timedelta(minutes=30)) -> Any: """Wait for query complete and get results - + :param query_id: ID of query. :param execution_timeout: how long to wait for the query to complete. """ result_set_count = self.client.wait_query_to_succeed( - query_id, - execution_timeout=execution_timeout, - stop_on_timeout=True + query_id, execution_timeout=execution_timeout, stop_on_timeout=True ) return self.client.get_query_all_result_sets(query_id=query_id, result_set_count=result_set_count) def stop_query(self, query_id: str) -> None: """Stop the query. - + :param query_id: ID of the query. """ self.client.stop_query(query_id) def get_query(self, query_id: str) -> Any: """Get query info. - + :param query_id: ID of the query. """ return self.client.get_query(query_id) def get_query_status(self, query_id: str) -> str: """Get status fo the query. - + :param query_id: ID of query. """ return self.client.get_query_status(query_id) def compose_query_web_link(self, query_id: str): """Compose web link to query in Yandex Query UI - + :param query_id: ID of query. """ return self.client.compose_query_web_link(query_id) @@ -129,20 +126,12 @@ def _get_iam_token(self) -> str: @staticmethod def _resolve_service_account_key(sa_info: dict) -> str: with YQHook._create_session() as session: - api = 'https://iam.api.cloud.yandex.net/iam/v1/tokens' + api = "https://iam.api.cloud.yandex.net/iam/v1/tokens" now = int(time.time()) - payload = { - 'aud': api, - 'iss': sa_info["service_account_id"], - 'iat': now, - 'exp': now + 360 - } + payload = {"aud": api, "iss": sa_info["service_account_id"], "iat": now, "exp": now + 360} encoded_token = jwt.encode( - payload, - sa_info["private_key"], - algorithm='PS256', - headers={'kid': sa_info["id"]} + payload, sa_info["private_key"], algorithm="PS256", headers={"kid": sa_info["id"]} ) data = {"jwt": encoded_token} @@ -156,18 +145,8 @@ def _create_session() -> requests.Session: session = requests.Session() session.verify = False session.timeout = 20 - retry = Retry( - backoff_factor=0.3, - total=10 - ) - session.mount( - 'http://', - requests.adapters.HTTPAdapter(max_retries=retry) - ) - session.mount( - 'https://', - requests.adapters.HTTPAdapter(max_retries=retry) - ) + retry = Retry(backoff_factor=0.3, total=10) + session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry)) + session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry)) return session - \ No newline at end of file From ec9707374a3642b3c379c095bf2d749ee37e9a7c Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 15 Feb 2024 20:00:40 +0300 Subject: [PATCH 19/34] restyling --- airflow/providers/yandex/hooks/http_client.py | 147 +++++----- .../providers/yandex/hooks/query_results.py | 48 ++-- .../yandex/operators/yandexcloud_yq.py | 10 +- .../yandex/hooks/test_yandexcloud_yq.py | 261 +++++++++++++++--- .../yandex/operators/test_yandexcloud_yq.py | 20 +- 5 files changed, 347 insertions(+), 139 deletions(-) diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/hooks/http_client.py index 1cb8ae64f8828..4e136c628a978 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/hooks/http_client.py @@ -37,26 +37,30 @@ ERROR_CODES = (500, 502, 504) -def requests_retry_session(session, - retries=MAX_RETRY_FOR_SESSION, - back_off_factor=BACK_OFF_FACTOR, - status_force_list=ERROR_CODES): - retry = Retry(total=retries, read=retries, connect=retries, - backoff_factor=back_off_factor, - status_forcelist=status_force_list, - allowed_methods=frozenset(['GET', 'POST'])) +def requests_retry_session( + session, retries=MAX_RETRY_FOR_SESSION, back_off_factor=BACK_OFF_FACTOR, status_force_list=ERROR_CODES +): + retry = Retry( + total=retries, + read=retries, + connect=retries, + backoff_factor=back_off_factor, + status_forcelist=status_force_list, + allowed_methods=frozenset(["GET", "POST"]), + ) adapter = HTTPAdapter(max_retries=retry) - session.mount('http://', adapter) - session.mount('https://', adapter) + session.mount("http://", adapter) + session.mount("https://", adapter) return session class YQHttpClientConfig(object): - def __init__(self, - token: str | None = None, - project: str | None = None, - user_agent: str | None = "Python YQ HTTP SDK") -> None: - + def __init__( + self, + token: str | None = None, + project: str | None = None, + user_agent: str | None = "Python YQ HTTP SDK", + ) -> None: assert len(token) > 0, "empty token" self.token = token self.project = project @@ -91,9 +95,7 @@ def close(self): self.session.close() def _build_headers(self, idempotency_key=None, request_id=None) -> dict[str, str]: - headers = { - "Authorization": f"{self.config.token_prefix}{self.config.token}" - } + headers = {"Authorization": f"{self.config.token_prefix}{self.config.token}"} if idempotency_key is not None: headers["Idempotency-Key"] = idempotency_key @@ -126,22 +128,25 @@ def _validate_http_error(self, response, expected_code=200) -> None: status = body.get("status") msg = body.get("message") details = body.get("details") - raise YQHttpClientException(f"Error occurred. http code={response.status_code}, status={status}, msg={msg}, details={details}", - status=status, - msg=msg, - details=details - ) + raise YQHttpClientException( + f"Error occurred. http code={response.status_code}, status={status}, msg={msg}, details={details}", + status=status, + msg=msg, + details=details, + ) raise YQHttpClientException(f"Error occurred: {response.status_code}, {response.text}") - def create_query(self, - query_text=None, - query_type=None, - name=None, - description=None, - idempotency_key=None, - request_id=None, - expected_code=200): + def create_query( + self, + query_text=None, + query_type=None, + name=None, + description=None, + idempotency_key=None, + request_id=None, + expected_code=200, + ): body = dict() if query_text is not None: body["text"] = query_text @@ -155,11 +160,12 @@ def create_query(self, if description is not None: body["description"] = description - response = self.session.post(self._compose_api_url("/api/fq/v1/queries"), - headers=self._build_headers(idempotency_key=idempotency_key, - request_id=request_id), - params=self._build_params(), - json=body) + response = self.session.post( + self._compose_api_url("/api/fq/v1/queries"), + headers=self._build_headers(idempotency_key=idempotency_key, request_id=request_id), + params=self._build_params(), + json=body, + ) self._validate_http_error(response, expected_code=expected_code) return response.json()["id"] @@ -168,7 +174,7 @@ def get_query_status(self, query_id, request_id=None, expected_code=200) -> Any: response = self.session.get( self._compose_api_url(f"/api/fq/v1/queries/{query_id}/status"), headers=self._build_headers(request_id=request_id), - params=self._build_params() + params=self._build_params(), ) self._validate_http_error(response, expected_code=expected_code) @@ -178,25 +184,25 @@ def get_query(self, query_id, request_id=None, expected_code=200) -> Any: response = self.session.get( self._compose_api_url(f"/api/fq/v1/queries/{query_id}"), headers=self._build_headers(request_id=request_id), - params=self._build_params() + params=self._build_params(), ) self._validate_http_error(response, expected_code=expected_code) return response.json() - def stop_query(self, - query_id: str, - idempotency_key: str | None = None, - request_id: str | None = None, - expected_code: int = 204) -> Any: - - headers = self._build_headers( - idempotency_key=idempotency_key, - request_id=request_id + def stop_query( + self, + query_id: str, + idempotency_key: str | None = None, + request_id: str | None = None, + expected_code: int = 204, + ) -> Any: + headers = self._build_headers(idempotency_key=idempotency_key, request_id=request_id) + response = self.session.post( + self._compose_api_url(f"/api/fq/v1/queries/{query_id}/stop"), + headers=headers, + params=self._build_params(), ) - response = self.session.post(self._compose_api_url(f"/api/fq/v1/queries/{query_id}/stop"), - headers=headers, - params=self._build_params()) self._validate_http_error(response, expected_code=expected_code) return response @@ -224,9 +230,7 @@ def wait_query_to_complete(self, query_id, execution_timeout=None, stop_on_timeo def wait_query_to_succeed(self, query_id, execution_timeout=None, stop_on_timeout=False) -> int: status = self.wait_query_to_complete( - query_id=query_id, - execution_timeout=execution_timeout, - stop_on_timeout=stop_on_timeout + query_id=query_id, execution_timeout=execution_timeout, stop_on_timeout=stop_on_timeout ) query = self.get_query(query_id) @@ -236,14 +240,16 @@ def wait_query_to_succeed(self, query_id, execution_timeout=None, stop_on_timeou return len(query["result_sets"]) - def get_query_result_set_page(self, - query_id, - result_set_index, - offset=None, - limit=None, - raw_format=False, - request_id=None, - expected_code=200) -> Any: + def get_query_result_set_page( + self, + query_id, + result_set_index, + offset=None, + limit=None, + raw_format=False, + request_id=None, + expected_code=200, + ) -> Any: params = self._build_params() if offset is not None: params["offset"] = offset @@ -254,7 +260,7 @@ def get_query_result_set_page(self, response = self.session.get( self._compose_api_url(f"/api/fq/v1/queries/{query_id}/results/{result_set_index}"), headers=self._build_headers(request_id=request_id), - params=params + params=params, ) self._validate_http_error(response, expected_code=expected_code) @@ -267,11 +273,7 @@ def get_query_result_set(self, query_id: str, result_set_index: int, raw_format: rows = [] while True: part = self.get_query_result_set_page( - query_id, - result_set_index=result_set_index, - offset=offset, - limit=limit, - raw_format=raw_format + query_id, result_set_index=result_set_index, offset=offset, limit=limit, raw_format=raw_format ) if columns is None: @@ -290,14 +292,12 @@ def get_query_result_set(self, query_id: str, result_set_index: int, raw_format: return YQResults(result).results - def get_query_all_result_sets(self, query_id: str, result_set_count: int, raw_format: bool = False) -> Any: + def get_query_all_result_sets( + self, query_id: str, result_set_count: int, raw_format: bool = False + ) -> Any: result = [] for i in range(0, result_set_count): - r = self.get_query_result_set( - query_id, - result_set_index=i, - raw_format=raw_format - ) + r = self.get_query_result_set(query_id, result_set_index=i, raw_format=raw_format) if result_set_count == 1: return r @@ -316,5 +316,6 @@ def compose_query_web_link(self, query_id) -> str: @staticmethod def result_set_to_dataframe(data): import pandas as pd + column_names = [column["name"] for column in data["columns"]] return pd.DataFrame(data["rows"], columns=column_names) diff --git a/airflow/providers/yandex/hooks/query_results.py b/airflow/providers/yandex/hooks/query_results.py index d61174b737296..8da38a1aeaf30 100644 --- a/airflow/providers/yandex/hooks/query_results.py +++ b/airflow/providers/yandex/hooks/query_results.py @@ -91,17 +91,17 @@ def _extract_from_optional(type_name: str) -> str: return type_name[0:-1] # Optional -> Uint16 - return type_name[len("Optional<"):-1] + return type_name[len("Optional<") : -1] @staticmethod def _extract_from_set(type_name: str) -> str: # Set -> Uint16 - return type_name[len("Set<"):-1] + return type_name[len("Set<") : -1] @staticmethod def _extract_from_list(type_name: str) -> str: # List -> Uint16 - return type_name[len("List<"):-1] + return type_name[len("List<") : -1] @staticmethod def _split_type_list(type_list: str) -> list[str]: @@ -112,18 +112,18 @@ def _split_type_list(type_list: str) -> list[str]: @staticmethod def _extract_from_tuple(type_name: str) -> list[str]: # Tuple -> [Uint16, String, Double] - return YQResults._split_type_list(type_name[len("Tuple<"):-1]) + return YQResults._split_type_list(type_name[len("Tuple<") : -1]) @staticmethod def _extract_from_dict(type_name: str) -> tuple[str, str]: # Dict -> (Uint16, String) - [key_type, value_type] = YQResults._split_type_list(type_name[len("Dict<"):-1]) + [key_type, value_type] = YQResults._split_type_list(type_name[len("Dict<") : -1]) return key_type, value_type @staticmethod def _extract_from_variant_over_struct(type_name: str) -> tuple[str, str]: # Variant<'One':Int32,'Two':String> -> {One: Int32, Two: String} - types_with_names = YQResults._split_type_list(type_name[len("Variant<"):-1]) + types_with_names = YQResults._split_type_list(type_name[len("Variant<") : -1]) result = {} for t in types_with_names: [n, t] = t.split(":") @@ -135,7 +135,7 @@ def _extract_from_variant_over_struct(type_name: str) -> tuple[str, str]: @staticmethod def _extract_from_variant_over_tuple(type_name: str) -> tuple[str, str]: # Variant -> [Int32, String] - return YQResults._split_type_list(type_name[len("Variant<"):-1]) + return YQResults._split_type_list(type_name[len("Variant<") : -1]) @staticmethod def _convert_from_optional(value: list[Any]) -> Optional[Any]: @@ -160,11 +160,24 @@ def _get_converter(column_type: str) -> Any: """Returns converter based on column type""" # primitives - if column_type in ["Int8", "Int16", "Int32", "Int64", - "Uint8", "Uint16", "Uint32", "Uint64", - "Bool", "Utf8", "Uuid", - "Void", "Null", - "EmptyList", "Struct<>", "Tuple<>"]: + if column_type in [ + "Int8", + "Int16", + "Int32", + "Int64", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Bool", + "Utf8", + "Uuid", + "Void", + "Null", + "EmptyList", + "Struct<>", + "Tuple<>", + ]: return YQResults.id if column_type == "String": @@ -185,8 +198,7 @@ def _get_converter(column_type: str) -> Any: # containers if column_type.startswith("Optional<") or column_type.endswith("?"): # If type is Optional than get base type - inner_converter = YQResults._get_converter( - YQResults._extract_from_optional(column_type)) + inner_converter = YQResults._get_converter(YQResults._extract_from_optional(column_type)) # Remove "Optional" encoding # and convert resulting value as others @@ -220,7 +232,8 @@ def convert(x): def convert(x): assert len(x) == len( - inner_converters), f"Wrong lenght for tuple value: {len(x)} != {len(inner_converters)}" + inner_converters + ), f"Wrong lenght for tuple value: {len(x)} != {len(inner_converters)}" return tuple([c(v) for (c, v) in zip(inner_converters, x)]) return convert @@ -246,6 +259,7 @@ def convert(x): return convert if column_type == "EmptyDict": + def convert(x): return {} @@ -290,8 +304,7 @@ def _convert(self): new_row = [] for index, value in enumerate(row): converter = converters[index] - new_row.append( - value if converter is None else converter(value)) + new_row.append(value if converter is None else converter(value)) converted_results.append(new_row) @@ -318,4 +331,5 @@ def to_dataframe(self): result_set = self._results columns = [column["name"] for column in result_set["columns"]] import pandas + return pandas.DataFrame(result_set["rows"], columns=columns) diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index efd7ba973a022..e18f4133279a9 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -27,7 +27,8 @@ if TYPE_CHECKING: from airflow.utils.context import Context -XCOM_WEBLINK_KEY="web_link" +XCOM_WEBLINK_KEY = "web_link" + class YQLink(BaseOperatorLink): name = "Yandex Query" @@ -80,13 +81,10 @@ def execute(self, context: Context) -> Any: yandex_conn_id=self.connection_id, default_folder_id=self.folder_id, default_public_ssh_key=self.public_ssh_key, - default_service_account_id=self.service_account_id + default_service_account_id=self.service_account_id, ) - self.query_id = self.hook.create_query( - query_text=self.sql, - name=self.name - ) + self.query_id = self.hook.create_query(query_text=self.sql, name=self.name) # pass to YQLink web_link = self.hook.compose_query_web_link(self.query_id) diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index 9956bfa849fd1..23fef99aec20b 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -27,7 +27,10 @@ from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook OAUTH_TOKEN = "my_oauth_token" -SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"-----BEGIN PRIVATE KEY----- my_pk"}""" +SERVICE_ACCOUNT_AUTH_KEY_JSON = ( + """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"-----BEGIN PRIVATE KEY----- my_pk"}""" +) + class TestYandexCloudYqHook: def _init_hook(self): @@ -49,9 +52,11 @@ def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", match=[ - matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": "Bearer super_token"} + ), matchers.query_param_matcher({"project": "my_folder_id"}), - matchers.json_params_matcher({"name": "my query", "text": "select 777"}) + matchers.json_params_matcher({"name": "my query", "text": "select 777"}), ], json={"id": "query1"}, status=200, @@ -85,14 +90,16 @@ def _create_test_query(self): query_id = self.hook.create_query(query_text="select 777", name="my query") assert query_id == "query1" return query_id - + @responses.activate() def test_oauth_token_usage(self): responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", match=[ - matchers.header_matcher({"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"}), - matchers.query_param_matcher({"project": "my_folder_id"}) + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"} + ), + matchers.query_param_matcher({"project": "my_folder_id"}), ], json={"id": "query1"}, status=200, @@ -105,39 +112,97 @@ def test_oauth_token_usage(self): @responses.activate() @mock.patch("jwt.encode") def test_select_results(self, mock_jwt): - self.setup_mocks_for_query_execution(mock_jwt, {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}) + self.setup_mocks_for_query_execution( + mock_jwt, {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} + ) self._init_hook() query_id = self._create_test_query() - assert self.hook.compose_query_web_link(query_id) == "https://yq.cloud.yandex.ru/folders/my_folder_id/ide/queries/query1" + assert ( + self.hook.compose_query_web_link(query_id) + == "https://yq.cloud.yandex.ru/folders/my_folder_id/ide/queries/query1" + ) results = self.hook.wait_results(query_id) - assert results == {"rows": [[777]], "columns": [ - {"name": "column0", "type": "Int32"}]} + assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", match=[ - matchers.header_matcher({"Content-Type": "application/json", "Authorization": "Bearer super_token"}), - matchers.query_param_matcher({"project": "my_folder_id"}) + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": "Bearer super_token"} + ), + matchers.query_param_matcher({"project": "my_folder_id"}), ], status=204, ) assert self.hook.get_query_status(query_id) == "COMPLETED" - assert self.hook.get_query(query_id) == {"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]} + assert self.hook.get_query(query_id) == { + "id": "query1", + "result_sets": [{"rows_count": 1, "truncated": False}], + } self.hook.stop_query(query_id) @responses.activate() @mock.patch("jwt.encode") def test_integral_results(self, mock_jwt): # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L336 - self.setup_mocks_for_query_execution(mock_jwt, + self.setup_mocks_for_query_execution( + mock_jwt, { - "rows":[[100,-100,200,200,10000000000,-20000000000,"18014398509481984","-18014398509481984",123.5,-789.125,"inf",True,False,"aGVsbG8=","hello","1.23","he\"llo_again","Я Привет",1,2,3,4]], - "columns":[{"name":"column0","type":"Int32"},{"name":"column1","type":"Int32"},{"name":"column2","type":"Int64"},{"name":"column3","type":"Uint64"},{"name":"column4","type":"Uint64"},{"name":"column5","type":"Int64"},{"name":"column6","type":"Int64"},{"name":"column7","type":"Int64"},{"name":"column8","type":"Float"},{"name":"column9","type":"Double"},{"name":"column10","type":"Double"},{"name":"column11","type":"Bool"},{"name":"column12","type":"Bool"},{"name":"column13","type":"String"},{"name":"column14","type":"Utf8"},{"name":"column15","type":"Decimal(6,3)"},{"name":"column16","type":"Utf8"},{"name":"column17","type":"Utf8"},{"name":"column18","type":"Int8"},{"name":"column19","type":"Int16"},{"name":"column20","type":"Uint8"},{"name":"column21","type":"Uint16"}] - } + "rows": [ + [ + 100, + -100, + 200, + 200, + 10000000000, + -20000000000, + "18014398509481984", + "-18014398509481984", + 123.5, + -789.125, + "inf", + True, + False, + "aGVsbG8=", + "hello", + "1.23", + 'he"llo_again', + "Я Привет", + 1, + 2, + 3, + 4, + ] + ], + "columns": [ + {"name": "column0", "type": "Int32"}, + {"name": "column1", "type": "Int32"}, + {"name": "column2", "type": "Int64"}, + {"name": "column3", "type": "Uint64"}, + {"name": "column4", "type": "Uint64"}, + {"name": "column5", "type": "Int64"}, + {"name": "column6", "type": "Int64"}, + {"name": "column7", "type": "Int64"}, + {"name": "column8", "type": "Float"}, + {"name": "column9", "type": "Double"}, + {"name": "column10", "type": "Double"}, + {"name": "column11", "type": "Bool"}, + {"name": "column12", "type": "Bool"}, + {"name": "column13", "type": "String"}, + {"name": "column14", "type": "Utf8"}, + {"name": "column15", "type": "Decimal(6,3)"}, + {"name": "column16", "type": "Utf8"}, + {"name": "column17", "type": "Utf8"}, + {"name": "column18", "type": "Int8"}, + {"name": "column19", "type": "Int16"}, + {"name": "column20", "type": "Uint8"}, + {"name": "column21", "type": "Uint16"}, + ], + }, ) self._init_hook() @@ -147,30 +212,128 @@ def test_integral_results(self, mock_jwt): assert results == { "rows": [ [ - 100, -100, - 200, 200, - 10000000000, -20000000000, - "18014398509481984", "-18014398509481984", - 123.5, -789.125, - float("inf"), True, - False, "hello", - "hello", Decimal("1.23"), - "he\"llo_again", "Я Привет", - 1, 2, 3, 4 + 100, + -100, + 200, + 200, + 10000000000, + -20000000000, + "18014398509481984", + "-18014398509481984", + 123.5, + -789.125, + float("inf"), + True, + False, + "hello", + "hello", + Decimal("1.23"), + 'he"llo_again', + "Я Привет", + 1, + 2, + 3, + 4, ] ], - "columns":[{"name":"column0","type":"Int32"},{"name":"column1","type":"Int32"},{"name":"column2","type":"Int64"},{"name":"column3","type":"Uint64"},{"name":"column4","type":"Uint64"},{"name":"column5","type":"Int64"},{"name":"column6","type":"Int64"},{"name":"column7","type":"Int64"},{"name":"column8","type":"Float"},{"name":"column9","type":"Double"},{"name":"column10","type":"Double"},{"name":"column11","type":"Bool"},{"name":"column12","type":"Bool"},{"name":"column13","type":"String"},{"name":"column14","type":"Utf8"},{"name":"column15","type":"Decimal(6,3)"},{"name":"column16","type":"Utf8"},{"name":"column17","type":"Utf8"},{"name":"column18","type":"Int8"},{"name":"column19","type":"Int16"},{"name":"column20","type":"Uint8"},{"name":"column21","type":"Uint16"}] + "columns": [ + {"name": "column0", "type": "Int32"}, + {"name": "column1", "type": "Int32"}, + {"name": "column2", "type": "Int64"}, + {"name": "column3", "type": "Uint64"}, + {"name": "column4", "type": "Uint64"}, + {"name": "column5", "type": "Int64"}, + {"name": "column6", "type": "Int64"}, + {"name": "column7", "type": "Int64"}, + {"name": "column8", "type": "Float"}, + {"name": "column9", "type": "Double"}, + {"name": "column10", "type": "Double"}, + {"name": "column11", "type": "Bool"}, + {"name": "column12", "type": "Bool"}, + {"name": "column13", "type": "String"}, + {"name": "column14", "type": "Utf8"}, + {"name": "column15", "type": "Decimal(6,3)"}, + {"name": "column16", "type": "Utf8"}, + {"name": "column17", "type": "Utf8"}, + {"name": "column18", "type": "Int8"}, + {"name": "column19", "type": "Int16"}, + {"name": "column20", "type": "Uint8"}, + {"name": "column21", "type": "Uint16"}, + ], } @responses.activate() @mock.patch("jwt.encode") def test_complex_results(self, mock_jwt): # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L445 - self.setup_mocks_for_query_execution(mock_jwt, + self.setup_mocks_for_query_execution( + mock_jwt, { - "rows": [[[], [1, 2], [], [["YWJj", 1]], [["xyz", 1]], None, "PT15M", "2019-09-16", "2019-09-16T10:46:05Z", "2019-09-16T11:27:44.345849Z", "2019-09-16,Europe/Moscow", "2019-09-16T14:32:40,Europe/Moscow", "2019-09-16T14:32:55.874913,Europe/Moscow", ["One", 12], [1, "eHl6"], ["a", 1], ["monday", None], 1, {}, {"a": 1, "b": "xyz"}, None, None, [[[1, [[177]]]]], [[[1, []]]], [[[1, []]]], ["Foo", None], ["Bar", None], [], [1, "cHJpdmV0", "2019-09-16"]]], - "columns": [{"name": "column0", "type": "EmptyList"}, {"name": "column1", "type": "List"}, {"name": "column2", "type": "EmptyDict"}, {"name": "column3", "type": "Dict"}, {"name": "column4", "type": "Dict"}, {"name": "column5", "type": "Uuid"}, {"name": "column6", "type": "Interval"}, {"name": "column7", "type": "Date"}, {"name": "column8", "type": "Datetime"}, {"name": "column9", "type": "Timestamp"}, {"name": "column10", "type": "TzDate"}, {"name": "column11", "type": "TzDatetime"}, {"name": "column12", "type": "TzTimestamp"}, {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, {"name": "column14", "type": "Variant"}, {"name": "column15", "type": "Variant<'a':Int32>"}, {"name": "column16", "type": "Enum<'monday'>"}, {"name": "column17", "type": "Tagged"}, {"name": "column18", "type": "Struct<>"}, {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, {"name": "column20", "type": "Void"}, {"name": "column21", "type": "Null"}, {"name": "column22", "type": "Optional?>"}, {"name": "column23", "type": "Optional?>"}, {"name": "column24", "type": "Optional?>"}, {"name": "column25", "type": "Enum<'Bar','Foo'>"}, {"name": "column26", "type": "Enum<'Bar','Foo'>"}, {"name": "column27", "type": "Tuple<>"}, {"name": "column28", "type": "Tuple"}] - } + "rows": [ + [ + [], + [1, 2], + [], + [["YWJj", 1]], + [["xyz", 1]], + None, + "PT15M", + "2019-09-16", + "2019-09-16T10:46:05Z", + "2019-09-16T11:27:44.345849Z", + "2019-09-16,Europe/Moscow", + "2019-09-16T14:32:40,Europe/Moscow", + "2019-09-16T14:32:55.874913,Europe/Moscow", + ["One", 12], + [1, "eHl6"], + ["a", 1], + ["monday", None], + 1, + {}, + {"a": 1, "b": "xyz"}, + None, + None, + [[[1, [[177]]]]], + [[[1, []]]], + [[[1, []]]], + ["Foo", None], + ["Bar", None], + [], + [1, "cHJpdmV0", "2019-09-16"], + ] + ], + "columns": [ + {"name": "column0", "type": "EmptyList"}, + {"name": "column1", "type": "List"}, + {"name": "column2", "type": "EmptyDict"}, + {"name": "column3", "type": "Dict"}, + {"name": "column4", "type": "Dict"}, + {"name": "column5", "type": "Uuid"}, + {"name": "column6", "type": "Interval"}, + {"name": "column7", "type": "Date"}, + {"name": "column8", "type": "Datetime"}, + {"name": "column9", "type": "Timestamp"}, + {"name": "column10", "type": "TzDate"}, + {"name": "column11", "type": "TzDatetime"}, + {"name": "column12", "type": "TzTimestamp"}, + {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, + {"name": "column14", "type": "Variant"}, + {"name": "column15", "type": "Variant<'a':Int32>"}, + {"name": "column16", "type": "Enum<'monday'>"}, + {"name": "column17", "type": "Tagged"}, + {"name": "column18", "type": "Struct<>"}, + {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, + {"name": "column20", "type": "Void"}, + {"name": "column21", "type": "Null"}, + {"name": "column22", "type": "Optional?>"}, + {"name": "column23", "type": "Optional?>"}, + {"name": "column24", "type": "Optional?>"}, + {"name": "column25", "type": "Enum<'Bar','Foo'>"}, + {"name": "column26", "type": "Enum<'Bar','Foo'>"}, + {"name": "column27", "type": "Tuple<>"}, + {"name": "column28", "type": "Tuple"}, + ], + }, ) self._init_hook() @@ -208,8 +371,38 @@ def test_complex_results(self, mock_jwt): "Foo", "Bar", [], - (1, "privet", datetime(2019, 9, 16, 0, 0)) + (1, "privet", datetime(2019, 9, 16, 0, 0)), ] ], - "columns": [{"name": "column0", "type": "EmptyList"}, {"name": "column1", "type": "List"}, {"name": "column2", "type": "EmptyDict"}, {"name": "column3", "type": "Dict"}, {"name": "column4", "type": "Dict"}, {"name": "column5", "type": "Uuid"}, {"name": "column6", "type": "Interval"}, {"name": "column7", "type": "Date"}, {"name": "column8", "type": "Datetime"}, {"name": "column9", "type": "Timestamp"}, {"name": "column10", "type": "TzDate"}, {"name": "column11", "type": "TzDatetime"}, {"name": "column12", "type": "TzTimestamp"}, {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, {"name": "column14", "type": "Variant"}, {"name": "column15", "type": "Variant<'a':Int32>"}, {"name": "column16", "type": "Enum<'monday'>"}, {"name": "column17", "type": "Tagged"}, {"name": "column18", "type": "Struct<>"}, {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, {"name": "column20", "type": "Void"}, {"name": "column21", "type": "Null"}, {"name": "column22", "type": "Optional?>"}, {"name": "column23", "type": "Optional?>"}, {"name": "column24", "type": "Optional?>"}, {"name": "column25", "type": "Enum<'Bar','Foo'>"}, {"name": "column26", "type": "Enum<'Bar','Foo'>"}, {"name": "column27", "type": "Tuple<>"}, {"name": "column28", "type": "Tuple"}] + "columns": [ + {"name": "column0", "type": "EmptyList"}, + {"name": "column1", "type": "List"}, + {"name": "column2", "type": "EmptyDict"}, + {"name": "column3", "type": "Dict"}, + {"name": "column4", "type": "Dict"}, + {"name": "column5", "type": "Uuid"}, + {"name": "column6", "type": "Interval"}, + {"name": "column7", "type": "Date"}, + {"name": "column8", "type": "Datetime"}, + {"name": "column9", "type": "Timestamp"}, + {"name": "column10", "type": "TzDate"}, + {"name": "column11", "type": "TzDatetime"}, + {"name": "column12", "type": "TzTimestamp"}, + {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, + {"name": "column14", "type": "Variant"}, + {"name": "column15", "type": "Variant<'a':Int32>"}, + {"name": "column16", "type": "Enum<'monday'>"}, + {"name": "column17", "type": "Tagged"}, + {"name": "column18", "type": "Struct<>"}, + {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, + {"name": "column20", "type": "Void"}, + {"name": "column21", "type": "Null"}, + {"name": "column22", "type": "Optional?>"}, + {"name": "column23", "type": "Optional?>"}, + {"name": "column24", "type": "Optional?>"}, + {"name": "column25", "type": "Enum<'Bar','Foo'>"}, + {"name": "column26", "type": "Enum<'Bar','Foo'>"}, + {"name": "column27", "type": "Tuple<>"}, + {"name": "column28", "type": "Tuple"}, + ], } diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py index f996bd3afff16..5bc4f835bc439 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -48,19 +48,17 @@ def setup_method(self): @patch("airflow.hooks.base.BaseHook.get_connection") def test_execute_query(self, mock_get_connection): mock_get_connection.return_value = Connection(extra={"oauth": OAUTH_TOKEN}) - operator = YQExecuteQueryOperator( - task_id="simple_sql", - sql="select 987", - folder_id="my_folder_id" - ) + operator = YQExecuteQueryOperator(task_id="simple_sql", sql="select 987", folder_id="my_folder_id") context = {"ti": MagicMock()} responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", match=[ - matchers.header_matcher({"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"}), + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"} + ), matchers.query_param_matcher({"project": FOLDER_ID}), - matchers.json_params_matcher({"text": "select 987"}) + matchers.json_params_matcher({"text": "select 987"}), ], json={"id": "query1"}, status=200, @@ -89,7 +87,9 @@ def test_execute_query(self, mock_get_connection): context["ti"].xcom_push.assert_has_calls( [ - call(key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1"), + call( + key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1" + ), ] ) @@ -105,5 +105,7 @@ def test_execute_query(self, mock_get_connection): status=200, ) - with pytest.raises(RuntimeError, match=re.escape("""Query query1 failed with issues=['some error']""")): + with pytest.raises( + RuntimeError, match=re.escape("""Query query1 failed with issues=['some error']""") + ): operator.execute(context) From 20fed73660ba410ae4128f4233de6653909964ea Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Fri, 16 Feb 2024 17:42:45 +0300 Subject: [PATCH 20/34] refactor, restyling --- airflow/providers/yandex/__init__.py | 2 +- .../providers/yandex/hooks/yandexcloud_yq.py | 40 ++++------------ airflow/providers/yandex/links/__init__.py | 16 +++++++ airflow/providers/yandex/links/yq.py | 41 ++++++++++++++++ .../yandex/operators/yandexcloud_yq.py | 21 ++------ airflow/providers/yandex/provider.yaml | 10 ++-- .../providers/yandex/yq_client/__init__.py | 16 +++++++ .../{hooks => yq_client}/http_client.py | 10 ++-- .../{hooks => yq_client}/query_results.py | 48 +++++++++---------- generated/provider_dependencies.json | 4 +- .../yandex/operators/test_yandexcloud_yq.py | 2 +- 11 files changed, 124 insertions(+), 86 deletions(-) create mode 100644 airflow/providers/yandex/links/__init__.py create mode 100644 airflow/providers/yandex/links/yq.py create mode 100644 airflow/providers/yandex/yq_client/__init__.py rename airflow/providers/yandex/{hooks => yq_client}/http_client.py (96%) rename airflow/providers/yandex/{hooks => yq_client}/query_results.py (85%) diff --git a/airflow/providers/yandex/__init__.py b/airflow/providers/yandex/__init__.py index d50e5d867783f..690909f9ce08d 100644 --- a/airflow/providers/yandex/__init__.py +++ b/airflow/providers/yandex/__init__.py @@ -27,7 +27,7 @@ __all__ = ["__version__"] -__version__ = "3.9.0" +__version__ = "3.8.0" try: from airflow import __version__ as airflow_version diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 02162ed74a162..1d0106fda207f 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -15,41 +15,23 @@ # specific language governing permissions and limitations # under the License. from __future__ import annotations -from requests.packages.urllib3.util.retry import Retry -from datetime import timedelta -import logging -import requests import time +from datetime import timedelta from typing import Any -import jwt - -# These two lines enable debugging at httplib level (requests->urllib3->http.client) -# You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. -# The only thing missing will be the response.body which is not logged. - -import http.client - -http.client.HTTPConnection.debuglevel = 1 -# You must initialize logging, otherwise you'll not see debug output. -logging.basicConfig() -logging.getLogger().setLevel(logging.DEBUG) -requests_log = logging.getLogger("requests.packages.urllib3") -requests_log.setLevel(logging.DEBUG) -requests_log.propagate = True +import jwt +import requests +from requests.packages.urllib3.util.retry import Retry -from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.exceptions import AirflowException +from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.providers.yandex.utils.user_agent import provider_user_agent - -from .http_client import YQHttpClientConfig, YQHttpClient +from airflow.providers.yandex.yq_client.http_client import YQHttpClient, YQHttpClientConfig class YQHook(YandexCloudBaseHook): - """ - A hook for Yandex Query - """ + """A hook for Yandex Query.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -61,7 +43,7 @@ def __init__(self, *args, **kwargs) -> None: self.client: YQHttpClient = YQHttpClient(config=config) def close(self): - """Release all resources""" + """Release all resources.""" self.client.close() def create_query(self, query_text: str | None, name: str | None = None) -> str: @@ -70,14 +52,13 @@ def create_query(self, query_text: str | None, name: str | None = None) -> str: :param query_text: SQL text. :param name: name for the query """ - return self.client.create_query( name=name, query_text=query_text, ) def wait_results(self, query_id: str, execution_timeout: timedelta = timedelta(minutes=30)) -> Any: - """Wait for query complete and get results + """Wait for query complete and get results. :param query_id: ID of query. :param execution_timeout: how long to wait for the query to complete. @@ -110,7 +91,7 @@ def get_query_status(self, query_id: str) -> str: return self.client.get_query_status(query_id) def compose_query_web_link(self, query_id: str): - """Compose web link to query in Yandex Query UI + """Compose web link to query in Yandex Query UI. :param query_id: ID of query. """ @@ -144,7 +125,6 @@ def _resolve_service_account_key(sa_info: dict) -> str: def _create_session() -> requests.Session: session = requests.Session() session.verify = False - session.timeout = 20 retry = Retry(backoff_factor=0.3, total=10) session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry)) session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry)) diff --git a/airflow/providers/yandex/links/__init__.py b/airflow/providers/yandex/links/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/yandex/links/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/yandex/links/yq.py b/airflow/providers/yandex/links/yq.py new file mode 100644 index 0000000000000..b168c5b0cf67e --- /dev/null +++ b/airflow/providers/yandex/links/yq.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.models import BaseOperatorLink, XCom + +if TYPE_CHECKING: + from airflow.models import BaseOperator + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.utils.context import Context + +XCOM_WEBLINK_KEY = "web_link" + + +class YQLink(BaseOperatorLink): + """Web link to query in Yandex Query UI.""" + + name = "Yandex Query" + + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): + return XCom.get_value(key=XCOM_WEBLINK_KEY, ti_key=ti_key) or "https://yq.cloud.yandex.ru" + + @staticmethod + def persist(context: Context, task_instance: BaseOperator, web_link: str) -> None: + task_instance.xcom_push(context, key=XCOM_WEBLINK_KEY, value=web_link) diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index e18f4133279a9..792b76bbee762 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -16,26 +16,15 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Any -from airflow.configuration import conf +from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -from airflow.models import BaseOperator, BaseOperatorLink, XCom -from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook +from airflow.providers.yandex.links.yq import YQLink if TYPE_CHECKING: from airflow.utils.context import Context -XCOM_WEBLINK_KEY = "web_link" - - -class YQLink(BaseOperatorLink): - name = "Yandex Query" - - def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): - return XCom.get_value(key=XCOM_WEBLINK_KEY, ti_key=ti_key) or "https://yq.cloud.yandex.ru" - class YQExecuteQueryOperator(SQLExecuteQueryOperator): """ @@ -62,12 +51,10 @@ def __init__( connection_id: str | None = None, public_ssh_key: str | None = None, service_account_id: str | None = None, - deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) self.name = name - self.deferrable = deferrable self.folder_id = folder_id self.connection_id = connection_id self.public_ssh_key = public_ssh_key @@ -88,7 +75,7 @@ def execute(self, context: Context) -> Any: # pass to YQLink web_link = self.hook.compose_query_web_link(self.query_id) - context["ti"].xcom_push(key=XCOM_WEBLINK_KEY, value=web_link) + YQLink.persist(context, self, web_link) results = self.hook.wait_results(self.query_id) # forget query to avoid 'stop_query' in on_kill @@ -98,3 +85,5 @@ def execute(self, context: Context) -> Any: def on_kill(self) -> None: if self.hook is not None and self.query_id is not None: self.hook.stop_query(self.query_id) + self.hook.close() + self.hook = None diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index 7a91ed9cd259b..f2402379f62de 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -19,8 +19,9 @@ package-name: apache-airflow-providers-yandex name: Yandex description: | - Yandex including `Yandex.Cloud `__ + This package is for Yandex, including: + - `Yandex.Cloud `__ state: ready source-date-epoch: 1707636562 # note that those versions are maintained by release manager - do not update them manually @@ -88,15 +89,16 @@ hooks: - integration-name: Yandex.Cloud YQ python-modules: - airflow.providers.yandex.hooks.yandexcloud_yq - - airflow.providers.yandex.hooks.http_client - - airflow.providers.yandex.hooks.query_results connection-types: - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook connection-type: yandexcloud +secrets-backends: + - airflow.providers.yandex.secrets.lockbox.LockboxSecretBackend + extra-links: - - airflow.providers.yandex.operators.yandexcloud_yq.YQLink + - airflow.providers.yandex.links.yq.YQLink config: yandex: diff --git a/airflow/providers/yandex/yq_client/__init__.py b/airflow/providers/yandex/yq_client/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/yandex/yq_client/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/yandex/hooks/http_client.py b/airflow/providers/yandex/yq_client/http_client.py similarity index 96% rename from airflow/providers/yandex/hooks/http_client.py rename to airflow/providers/yandex/yq_client/http_client.py index 4e136c628a978..7bfac19933c94 100644 --- a/airflow/providers/yandex/hooks/http_client.py +++ b/airflow/providers/yandex/yq_client/http_client.py @@ -14,10 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# This file is a copy of https://github.com/ydb-platform/ydb/tree/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/core/fq/libs/http_api_client -# It is highly recommended to modify original file first in YDB project and merge it here afterwards - from __future__ import annotations import logging @@ -29,7 +25,7 @@ from typing import Any from urllib3.util.retry import Retry -from .query_results import YQResults +from airflow.providers.yandex.yq_client.query_results import YQResults MAX_RETRY_FOR_SESSION = 4 BACK_OFF_FACTOR = 0.3 @@ -54,7 +50,7 @@ def requests_retry_session( return session -class YQHttpClientConfig(object): +class YQHttpClientConfig: def __init__( self, token: str | None = None, @@ -147,7 +143,7 @@ def create_query( request_id=None, expected_code=200, ): - body = dict() + body = {} if query_text is not None: body["text"] = query_text diff --git a/airflow/providers/yandex/hooks/query_results.py b/airflow/providers/yandex/yq_client/query_results.py similarity index 85% rename from airflow/providers/yandex/hooks/query_results.py rename to airflow/providers/yandex/yq_client/query_results.py index 8da38a1aeaf30..c258e44c21ff0 100644 --- a/airflow/providers/yandex/hooks/query_results.py +++ b/airflow/providers/yandex/yq_client/query_results.py @@ -14,12 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# This file is a copy of https://github.com/ydb-platform/ydb/tree/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/core/fq/libs/http_api_client -# It is highly recommended to modify original file first in YDB project and merge it here afterwards - from __future__ import annotations -from typing import Any, Optional + +from typing import Any import base64 import pprint import dateutil.parser @@ -28,7 +25,7 @@ class YQResults: - """Holds and formats query execution results""" + """Holds and formats query execution results.""" def __init__(self, results: dict[str, Any]): self._raw_results = results @@ -40,13 +37,13 @@ def _convert_from_float(value: float | str) -> float: return float(value) @staticmethod - def _convert_from_pgfloat(value: str | None) -> Optional[float]: + def _convert_from_pgfloat(value: str | None) -> float | None: if value is None: return None return float(value) @staticmethod - def _convert_from_pgint(value: str | None) -> Optional[int]: + def _convert_from_pgint(value: str | None) -> int | None: if value is None: return None return int(value) @@ -56,7 +53,7 @@ def _convert_from_decimal(value: str) -> Decimal: return Decimal(value) @staticmethod - def _convert_from_pgnumeric(value: str | None) -> Optional[Decimal]: + def _convert_from_pgnumeric(value: str | None) -> Decimal | None: if value is None: return None return Decimal(value) @@ -75,7 +72,7 @@ def _convert_from_datetime(value: str) -> datetime: return dateutil.parser.isoparse(value) @staticmethod - def _convert_from_pgdatetime(value: str | None) -> Optional[datetime]: + def _convert_from_pgdatetime(value: str | None) -> datetime | None: if value is None: return None return dateutil.parser.isoparse(value) @@ -121,7 +118,7 @@ def _extract_from_dict(type_name: str) -> tuple[str, str]: return key_type, value_type @staticmethod - def _extract_from_variant_over_struct(type_name: str) -> tuple[str, str]: + def _extract_from_variant_over_struct(type_name: str) -> dict[str, str]: # Variant<'One':Int32,'Two':String> -> {One: Int32, Two: String} types_with_names = YQResults._split_type_list(type_name[len("Variant<") : -1]) result = {} @@ -133,12 +130,12 @@ def _extract_from_variant_over_struct(type_name: str) -> tuple[str, str]: return result @staticmethod - def _extract_from_variant_over_tuple(type_name: str) -> tuple[str, str]: + def _extract_from_variant_over_tuple(type_name: str) -> list[str]: # Variant -> [Int32, String] return YQResults._split_type_list(type_name[len("Variant<") : -1]) @staticmethod - def _convert_from_optional(value: list[Any]) -> Optional[Any]: + def _convert_from_optional(value: list[Any]) -> Any: # Optional types are encoded as [[]] objects # If type is Uint16, value is encoded as {"rows":[[value]]} # If type is Optional, value is encoded as {"rows":[[[value]]]} @@ -157,7 +154,7 @@ def id(v): @staticmethod def _get_converter(column_type: str) -> Any: - """Returns converter based on column type""" + """Returns converter based on column type.""" # primitives if column_type in [ @@ -227,39 +224,38 @@ def convert(x): return convert if column_type.startswith("Tuple<"): - inner_types = YQResults._extract_from_tuple(column_type) - inner_converters = [YQResults._get_converter(t) for t in inner_types] + inner_types_list = YQResults._extract_from_tuple(column_type) + inner_converters_list = [YQResults._get_converter(t) for t in inner_types_list] def convert(x): assert len(x) == len( - inner_converters - ), f"Wrong lenght for tuple value: {len(x)} != {len(inner_converters)}" - return tuple([c(v) for (c, v) in zip(inner_converters, x)]) + inner_converters_list + ), f"Wrong length for tuple value: {len(x)} != {len(inner_converters_list)}" + return tuple([c(v) for (c, v) in zip(inner_converters_list, x)]) return convert # variant over struct if column_type.startswith("Variant<'"): - inner_types = YQResults._extract_from_variant_over_struct(column_type) - inner_converters = {k: YQResults._get_converter(t) for k, t in inner_types.items()} + inner_types_dict = YQResults._extract_from_variant_over_struct(column_type) + inner_converters_dict = {k: YQResults._get_converter(t) for k, t in inner_types_dict.items()} def convert(x): - return inner_converters[x[0]](x[1]) + return inner_converters_dict[x[0]](x[1]) return convert # variant over tuple if column_type.startswith("Variant<"): - inner_types = YQResults._extract_from_variant_over_tuple(column_type) - inner_converters = [YQResults._get_converter(t) for t in inner_types] + inner_types_list = YQResults._extract_from_variant_over_tuple(column_type) + inner_converters_list = [YQResults._get_converter(t) for t in inner_types_list] def convert(x): - return inner_converters[x[0]](x[1]) + return inner_converters_list[x[0]](x[1]) return convert if column_type == "EmptyDict": - def convert(x): return {} diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 10d50d9bbad3a..c751aa92ddc3a 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1184,7 +1184,9 @@ "yandexcloud>=0.228.0" ], "devel-deps": [], - "cross-providers-deps": [], + "cross-providers-deps": [ + "common.sql" + ], "excluded-python-versions": [], "state": "ready" }, diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py index 5bc4f835bc439..99cbc3a103a89 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -88,7 +88,7 @@ def test_execute_query(self, mock_get_connection): context["ti"].xcom_push.assert_has_calls( [ call( - key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1" + key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", execution_date=None ), ] ) From ebe4a47332fb5d912223ba25fe34314e1e31c2cf Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Fri, 16 Feb 2024 17:44:30 +0300 Subject: [PATCH 21/34] revert version --- airflow/providers/yandex/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/yandex/__init__.py b/airflow/providers/yandex/__init__.py index 690909f9ce08d..d50e5d867783f 100644 --- a/airflow/providers/yandex/__init__.py +++ b/airflow/providers/yandex/__init__.py @@ -27,7 +27,7 @@ __all__ = ["__version__"] -__version__ = "3.8.0" +__version__ = "3.9.0" try: from airflow import __version__ as airflow_version From e6badc2f9e31d7eeda45452e0efef1293a533c3c Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 17 Feb 2024 11:35:00 +0300 Subject: [PATCH 22/34] change text to trigger CI checks --- airflow/providers/yandex/hooks/yandexcloud_yq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yandexcloud_yq.py index 1d0106fda207f..1668b087eb54c 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yandexcloud_yq.py @@ -31,7 +31,7 @@ class YQHook(YandexCloudBaseHook): - """A hook for Yandex Query.""" + """A hook to work with Yandex Query.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) From 8a8cc974009b165c0a7ac35ac24681ab5c10cbbb Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 17 Feb 2024 22:42:28 +0300 Subject: [PATCH 23/34] fixes for linters --- airflow/providers/yandex/yq_client/http_client.py | 14 ++++++++++---- .../providers/yandex/yq_client/query_results.py | 7 ++++--- .../providers/yandex/hooks/test_yandexcloud_yq.py | 4 +--- .../yandex/operators/test_yandexcloud_yq.py | 11 +++++++---- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/airflow/providers/yandex/yq_client/http_client.py b/airflow/providers/yandex/yq_client/http_client.py index 7bfac19933c94..30a9ab4c05600 100644 --- a/airflow/providers/yandex/yq_client/http_client.py +++ b/airflow/providers/yandex/yq_client/http_client.py @@ -33,7 +33,7 @@ ERROR_CODES = (500, 502, 504) -def requests_retry_session( +def _requests_retry_session( session, retries=MAX_RETRY_FOR_SESSION, back_off_factor=BACK_OFF_FACTOR, status_force_list=ERROR_CODES ): retry = Retry( @@ -51,6 +51,8 @@ def requests_retry_session( class YQHttpClientConfig: + """YandexQuery HTTP client config.""" + def __init__( self, token: str | None = None, @@ -69,17 +71,21 @@ def __init__( class YQHttpClientException(Exception): - def __init__(self, message: str, status: str = None, msg: str = None, details: Any = None) -> None: + """YandexQuery client exception type.""" + + def __init__(self, message: str, status: str | None = None, msg: str | None = None, details: Any = None) -> None: super().__init__(message) self.status = status self.msg = msg self.details = details -class YQHttpClient(object): +class YQHttpClient: + """YandexQuery HTTP client.""" + def __init__(self, config: YQHttpClientConfig): self.config = config - self.session = requests_retry_session(session=requests.Session()) + self.session = _requests_retry_session(session=requests.Session()) def __enter__(self): return self diff --git a/airflow/providers/yandex/yq_client/query_results.py b/airflow/providers/yandex/yq_client/query_results.py index c258e44c21ff0..1b4ec75e81d21 100644 --- a/airflow/providers/yandex/yq_client/query_results.py +++ b/airflow/providers/yandex/yq_client/query_results.py @@ -16,12 +16,13 @@ # under the License. from __future__ import annotations -from typing import Any import base64 import pprint -import dateutil.parser from datetime import datetime from decimal import Decimal +from typing import Any + +import dateutil.parser class YQResults: @@ -155,7 +156,6 @@ def id(v): @staticmethod def _get_converter(column_type: str) -> Any: """Returns converter based on column type.""" - # primitives if column_type in [ "Int8", @@ -256,6 +256,7 @@ def convert(x): return convert if column_type == "EmptyDict": + def convert(x): return {} diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/hooks/test_yandexcloud_yq.py index 23fef99aec20b..ea53f082d6fda 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_yq.py @@ -130,9 +130,7 @@ def test_select_results(self, mock_jwt): responses.post( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", match=[ - matchers.header_matcher( - {"Content-Type": "application/json", "Authorization": "Bearer super_token"} - ), + matchers.header_matcher({"Authorization": "Bearer super_token"}), matchers.query_param_matcher({"project": "my_folder_id"}), ], status=204, diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yandexcloud_yq.py index 99cbc3a103a89..ee465804f8bb7 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yandexcloud_yq.py @@ -16,15 +16,16 @@ # under the License. from __future__ import annotations +import re from datetime import datetime, timedelta +from unittest.mock import MagicMock, call, patch + import pytest import responses from responses import matchers -import re -from unittest.mock import MagicMock, call, patch -from airflow.models.dag import DAG from airflow.models import Connection +from airflow.models.dag import DAG from airflow.providers.yandex.operators.yandexcloud_yq import YQExecuteQueryOperator OAUTH_TOKEN = "my_oauth_token" @@ -88,7 +89,9 @@ def test_execute_query(self, mock_get_connection): context["ti"].xcom_push.assert_has_calls( [ call( - key="web_link", value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", execution_date=None + key="web_link", + value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1", + execution_date=None, ), ] ) From ece1e0c75e263ea5c155a90f7dc2a9640e4ecca7 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 2 Mar 2024 16:43:00 +0300 Subject: [PATCH 24/34] rework --- .../yandex/hooks/{yandexcloud_yq.py => yq.py} | 2 +- .../yandex/operators/yandexcloud_yq.py | 2 +- airflow/providers/yandex/provider.yaml | 2 +- .../providers/yandex/yq_client/http_client.py | 1 - tests/providers/yandex/hooks/test_yq.py | 114 ++++++++++++++ tests/providers/yandex/yq_client/__init__.py | 16 ++ .../yandex/yq_client/test_http_client.py | 124 +++++++++++++++ .../test_query_results.py} | 145 +----------------- 8 files changed, 265 insertions(+), 141 deletions(-) rename airflow/providers/yandex/hooks/{yandexcloud_yq.py => yq.py} (99%) create mode 100644 tests/providers/yandex/hooks/test_yq.py create mode 100644 tests/providers/yandex/yq_client/__init__.py create mode 100644 tests/providers/yandex/yq_client/test_http_client.py rename tests/providers/yandex/{hooks/test_yandexcloud_yq.py => yq_client/test_query_results.py} (70%) diff --git a/airflow/providers/yandex/hooks/yandexcloud_yq.py b/airflow/providers/yandex/hooks/yq.py similarity index 99% rename from airflow/providers/yandex/hooks/yandexcloud_yq.py rename to airflow/providers/yandex/hooks/yq.py index 1668b087eb54c..1d0106fda207f 100644 --- a/airflow/providers/yandex/hooks/yandexcloud_yq.py +++ b/airflow/providers/yandex/hooks/yq.py @@ -31,7 +31,7 @@ class YQHook(YandexCloudBaseHook): - """A hook to work with Yandex Query.""" + """A hook for Yandex Query.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yandexcloud_yq.py index 792b76bbee762..a7f5d9ba24d76 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yandexcloud_yq.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator -from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook +from airflow.providers.yandex.hooks.yq import YQHook from airflow.providers.yandex.links.yq import YQLink if TYPE_CHECKING: diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index f2402379f62de..12627e56557fd 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -88,7 +88,7 @@ hooks: - airflow.providers.yandex.hooks.yandexcloud_dataproc - integration-name: Yandex.Cloud YQ python-modules: - - airflow.providers.yandex.hooks.yandexcloud_yq + - airflow.providers.yandex.hooks.yq connection-types: - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook diff --git a/airflow/providers/yandex/yq_client/http_client.py b/airflow/providers/yandex/yq_client/http_client.py index 30a9ab4c05600..e182db96bd0bf 100644 --- a/airflow/providers/yandex/yq_client/http_client.py +++ b/airflow/providers/yandex/yq_client/http_client.py @@ -59,7 +59,6 @@ def __init__( project: str | None = None, user_agent: str | None = "Python YQ HTTP SDK", ) -> None: - assert len(token) > 0, "empty token" self.token = token self.project = project self.user_agent = user_agent diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py new file mode 100644 index 0000000000000..1a2359077010e --- /dev/null +++ b/tests/providers/yandex/hooks/test_yq.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import timedelta +import responses +from responses import matchers +from unittest import mock + +from airflow.models import Connection +from airflow.providers.yandex.hooks.yq import YQHook + +OAUTH_TOKEN = "my_oauth_token" +SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"my_pk"}""" + + +class DummySDK: + def __init__(self) -> None: + self.client = None + + +class TestYandexCloudYqHook: + def _init_hook(self): + with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection: + mock_get_connection.return_value = self.connection + self.hook = YQHook(default_folder_id="my_folder_id") + + def setup_method(self): + self.connection = Connection(extra={"service_account_json": SERVICE_ACCOUNT_AUTH_KEY_JSON}) + + @responses.activate() + def test_oauth_token_usage(self): + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"} + ), + matchers.query_param_matcher({"project": "my_folder_id"}), + ], + json={"id": "query1"}, + status=200, + ) + + self.connection = Connection(extra={"oauth": OAUTH_TOKEN}) + self._init_hook() + query_id = self.hook.create_query(query_text="select 777", name="my query") + assert query_id == "query1" + + with mock.patch("airflow.providers.yandex.yq_client.http_client.YQHttpClient.compose_query_web_link") as m: + m.return_value = "http://gg.zz" + assert self.hook.compose_query_web_link("query1") == "http://gg.zz" + m.assert_called_once_with("query1") + + @responses.activate() + @mock.patch("yandexcloud.SDK") + @mock.patch("jwt.encode") + def test_select_results(self, mock_jwt, mock_sdk): + responses.post( + "https://iam.api.cloud.yandex.net/iam/v1/tokens", + json={"iamToken": "super_token"}, + status=200, + ) + + mock_jwt.return_value = "zzzz" + mock_sdk.return_value = DummySDK() + + with mock.patch.multiple( + "airflow.providers.yandex.yq_client.http_client.YQHttpClient", + create_query=mock.DEFAULT, + wait_query_to_succeed=mock.DEFAULT, + get_query_all_result_sets=mock.DEFAULT, + get_query_status=mock.DEFAULT, + get_query=mock.DEFAULT, + stop_query=mock.DEFAULT, + ) as mocks: + self._init_hook() + mocks["create_query"].return_value = "query1" + mocks["wait_query_to_succeed"].return_value = 2 + mocks["get_query_all_result_sets"].return_value = {"x": 765} + mocks["get_query_status"].return_value = "COMPLETED" + mocks["get_query"].return_value = {"id": "my_q"} + + query_id = self.hook.create_query(query_text="select 777", name="my query") + assert query_id == "query1" + mocks["create_query"].assert_called_once_with(query_text="select 777", name="my query") + + results = self.hook.wait_results(query_id, execution_timeout=timedelta(minutes=10)) + assert results == {"x": 765} + mocks["wait_query_to_succeed"].assert_called_once_with(query_id, execution_timeout=timedelta(minutes=10), stop_on_timeout=True) + mocks["get_query_all_result_sets"].assert_called_once_with(query_id=query_id, result_set_count=2) + + assert self.hook.get_query_status(query_id) == "COMPLETED" + mocks["get_query_status"].assert_called_once_with(query_id) + + assert self.hook.get_query(query_id) == {"id": "my_q"} + mocks["get_query"].assert_called_once_with(query_id) + + self.hook.stop_query(query_id) + mocks["stop_query"].assert_called_once_with(query_id) diff --git a/tests/providers/yandex/yq_client/__init__.py b/tests/providers/yandex/yq_client/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/yandex/yq_client/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/yandex/yq_client/test_http_client.py b/tests/providers/yandex/yq_client/test_http_client.py new file mode 100644 index 0000000000000..74485005d0621 --- /dev/null +++ b/tests/providers/yandex/yq_client/test_http_client.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import responses +from responses import matchers + +from airflow.providers.yandex.yq_client.http_client import YQHttpClient, YQHttpClientConfig + +IAM_TOKEN = "my_iam_token" +PROJECT="my_project" + + +class TestYQHttpClient: + def setup_method(self): + config = YQHttpClientConfig(IAM_TOKEN, PROJECT) + self.client = YQHttpClient(config) + + def setup_mocks_for_query_execution(self, query_results_json_list): + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", + match=[ + matchers.header_matcher( + {"Content-Type": "application/json", "Authorization": f"Bearer {IAM_TOKEN}"} + ), + matchers.query_param_matcher({"project": PROJECT}), + matchers.json_params_matcher({"name": "my query", "text": "select 777"}), + ], + json={"id": "query1"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "RUNNING"}, + status=200, + ) + + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", + json={"status": "COMPLETED"}, + status=200, + ) + + result_set_count = len(query_results_json_list) + responses.get( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", + json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False} for _ in range(result_set_count)]}, + status=200, + ) + + for i in range(result_set_count): + responses.get( + f"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/{i}", + json=query_results_json_list[i], + status=200, + ) + + def _create_test_query(self): + query_id = self.client.create_query(query_text="select 777", name="my query") + assert query_id == "query1" + return query_id + + @responses.activate() + def test_select_results(self): + self.setup_mocks_for_query_execution( + [{"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}] + ) + + query_id = self._create_test_query() + + result_set_count = self.client.wait_query_to_succeed(query_id) + assert result_set_count == 1 + results = self.client.get_query_all_result_sets(query_id, result_set_count=result_set_count) + assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} + + responses.post( + "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", + match=[ + matchers.header_matcher({"Authorization": f"Bearer {IAM_TOKEN}"}), + matchers.query_param_matcher({"project": PROJECT}), + ], + status=204, + ) + + assert self.client.get_query_status(query_id) == "COMPLETED" + assert self.client.get_query(query_id) == { + "id": "query1", + "result_sets": [{"rows_count": 1, "truncated": False}], + } + self.client.stop_query(query_id) + + assert self.client.compose_query_web_link(query_id) == f"https://yq.cloud.yandex.ru/folders/{PROJECT}/ide/queries/query1" + + @responses.activate() + def test_select_two_record_sets(self): + self.setup_mocks_for_query_execution( + [ + {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}, + {"rows": [["zzz"]], "columns": [{"name": "aaaa", "type": "Utf8"}]}, + ] + ) + + query_id = self._create_test_query() + + result_set_count = self.client.wait_query_to_succeed(query_id) + assert result_set_count == 2 + results = self.client.get_query_all_result_sets(query_id, result_set_count=result_set_count) + assert results[0] == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} + assert results[1] == {"rows": [["zzz"]], "columns": [{"name": "aaaa", "type": "Utf8"}]} diff --git a/tests/providers/yandex/hooks/test_yandexcloud_yq.py b/tests/providers/yandex/yq_client/test_query_results.py similarity index 70% rename from tests/providers/yandex/hooks/test_yandexcloud_yq.py rename to tests/providers/yandex/yq_client/test_query_results.py index ea53f082d6fda..f445cffe0b19d 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_yq.py +++ b/tests/providers/yandex/yq_client/test_query_results.py @@ -16,139 +16,17 @@ # under the License. from __future__ import annotations -import responses from datetime import datetime from dateutil.tz import tzutc from decimal import Decimal -from responses import matchers -from unittest import mock -from airflow.models import Connection -from airflow.providers.yandex.hooks.yandexcloud_yq import YQHook +from airflow.providers.yandex.yq_client.query_results import YQResults -OAUTH_TOKEN = "my_oauth_token" -SERVICE_ACCOUNT_AUTH_KEY_JSON = ( - """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"-----BEGIN PRIVATE KEY----- my_pk"}""" -) - -class TestYandexCloudYqHook: - def _init_hook(self): - with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection: - mock_get_connection.return_value = self.connection - self.hook = YQHook(default_folder_id="my_folder_id") - - def setup_method(self): - self.connection = Connection(extra={"service_account_json": SERVICE_ACCOUNT_AUTH_KEY_JSON}) - - def setup_mocks_for_query_execution(self, mock_jwt, query_results_json): - responses.post( - "https://iam.api.cloud.yandex.net/iam/v1/tokens", - json={"iamToken": "super_token"}, - status=200, - ) - mock_jwt.return_value = "zzzz" - - responses.post( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", - match=[ - matchers.header_matcher( - {"Content-Type": "application/json", "Authorization": "Bearer super_token"} - ), - matchers.query_param_matcher({"project": "my_folder_id"}), - matchers.json_params_matcher({"name": "my query", "text": "select 777"}), - ], - json={"id": "query1"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", - json={"status": "RUNNING"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", - json={"status": "COMPLETED"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", - json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False}]}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/0", - json=query_results_json, - status=200, - ) - - def _create_test_query(self): - query_id = self.hook.create_query(query_text="select 777", name="my query") - assert query_id == "query1" - return query_id - - @responses.activate() - def test_oauth_token_usage(self): - responses.post( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", - match=[ - matchers.header_matcher( - {"Content-Type": "application/json", "Authorization": f"Bearer {OAUTH_TOKEN}"} - ), - matchers.query_param_matcher({"project": "my_folder_id"}), - ], - json={"id": "query1"}, - status=200, - ) - - self.connection = Connection(extra={"oauth": OAUTH_TOKEN}) - self._init_hook() - self._create_test_query() - - @responses.activate() - @mock.patch("jwt.encode") - def test_select_results(self, mock_jwt): - self.setup_mocks_for_query_execution( - mock_jwt, {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - ) - - self._init_hook() - query_id = self._create_test_query() - - assert ( - self.hook.compose_query_web_link(query_id) - == "https://yq.cloud.yandex.ru/folders/my_folder_id/ide/queries/query1" - ) - - results = self.hook.wait_results(query_id) - assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - - responses.post( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", - match=[ - matchers.header_matcher({"Authorization": "Bearer super_token"}), - matchers.query_param_matcher({"project": "my_folder_id"}), - ], - status=204, - ) - - assert self.hook.get_query_status(query_id) == "COMPLETED" - assert self.hook.get_query(query_id) == { - "id": "query1", - "result_sets": [{"rows_count": 1, "truncated": False}], - } - self.hook.stop_query(query_id) - - @responses.activate() - @mock.patch("jwt.encode") - def test_integral_results(self, mock_jwt): +class TestYQResults: + def test_integral_results(self): # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L336 - self.setup_mocks_for_query_execution( - mock_jwt, + r = YQResults( { "rows": [ [ @@ -203,10 +81,8 @@ def test_integral_results(self, mock_jwt): }, ) - self._init_hook() - query_id = self._create_test_query() + results = r.results - results = self.hook.wait_results(query_id) assert results == { "rows": [ [ @@ -260,12 +136,9 @@ def test_integral_results(self, mock_jwt): ], } - @responses.activate() - @mock.patch("jwt.encode") - def test_complex_results(self, mock_jwt): + def test_complex_results(self): # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L445 - self.setup_mocks_for_query_execution( - mock_jwt, + r = YQResults( { "rows": [ [ @@ -334,10 +207,8 @@ def test_complex_results(self, mock_jwt): }, ) - self._init_hook() - query_id = self._create_test_query() + results = r.results - results = self.hook.wait_results(query_id) assert results == { "rows": [ [ From 843cd93a4b1ce1a2aae82694acf6c1072ad40478 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 2 Mar 2024 16:44:20 +0300 Subject: [PATCH 25/34] restyling --- airflow/providers/yandex/yq_client/http_client.py | 8 +++++--- tests/providers/yandex/hooks/test_yq.py | 8 ++++++-- tests/providers/yandex/yq_client/test_http_client.py | 12 +++++++++--- .../providers/yandex/yq_client/test_query_results.py | 3 ++- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/airflow/providers/yandex/yq_client/http_client.py b/airflow/providers/yandex/yq_client/http_client.py index e182db96bd0bf..e7c57f65c0ec9 100644 --- a/airflow/providers/yandex/yq_client/http_client.py +++ b/airflow/providers/yandex/yq_client/http_client.py @@ -19,10 +19,10 @@ import logging import time from datetime import datetime -import requests +from typing import Any +import requests from requests.adapters import HTTPAdapter -from typing import Any from urllib3.util.retry import Retry from airflow.providers.yandex.yq_client.query_results import YQResults @@ -72,7 +72,9 @@ def __init__( class YQHttpClientException(Exception): """YandexQuery client exception type.""" - def __init__(self, message: str, status: str | None = None, msg: str | None = None, details: Any = None) -> None: + def __init__( + self, message: str, status: str | None = None, msg: str | None = None, details: Any = None + ) -> None: super().__init__(message) self.status = status self.msg = msg diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py index 1a2359077010e..d95b340daf127 100644 --- a/tests/providers/yandex/hooks/test_yq.py +++ b/tests/providers/yandex/hooks/test_yq.py @@ -61,7 +61,9 @@ def test_oauth_token_usage(self): query_id = self.hook.create_query(query_text="select 777", name="my query") assert query_id == "query1" - with mock.patch("airflow.providers.yandex.yq_client.http_client.YQHttpClient.compose_query_web_link") as m: + with mock.patch( + "airflow.providers.yandex.yq_client.http_client.YQHttpClient.compose_query_web_link" + ) as m: m.return_value = "http://gg.zz" assert self.hook.compose_query_web_link("query1") == "http://gg.zz" m.assert_called_once_with("query1") @@ -101,7 +103,9 @@ def test_select_results(self, mock_jwt, mock_sdk): results = self.hook.wait_results(query_id, execution_timeout=timedelta(minutes=10)) assert results == {"x": 765} - mocks["wait_query_to_succeed"].assert_called_once_with(query_id, execution_timeout=timedelta(minutes=10), stop_on_timeout=True) + mocks["wait_query_to_succeed"].assert_called_once_with( + query_id, execution_timeout=timedelta(minutes=10), stop_on_timeout=True + ) mocks["get_query_all_result_sets"].assert_called_once_with(query_id=query_id, result_set_count=2) assert self.hook.get_query_status(query_id) == "COMPLETED" diff --git a/tests/providers/yandex/yq_client/test_http_client.py b/tests/providers/yandex/yq_client/test_http_client.py index 74485005d0621..2ed5ea6fec79a 100644 --- a/tests/providers/yandex/yq_client/test_http_client.py +++ b/tests/providers/yandex/yq_client/test_http_client.py @@ -22,7 +22,7 @@ from airflow.providers.yandex.yq_client.http_client import YQHttpClient, YQHttpClientConfig IAM_TOKEN = "my_iam_token" -PROJECT="my_project" +PROJECT = "my_project" class TestYQHttpClient: @@ -59,7 +59,10 @@ def setup_mocks_for_query_execution(self, query_results_json_list): result_set_count = len(query_results_json_list) responses.get( "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", - json={"id": "query1", "result_sets": [{"rows_count": 1, "truncated": False} for _ in range(result_set_count)]}, + json={ + "id": "query1", + "result_sets": [{"rows_count": 1, "truncated": False} for _ in range(result_set_count)], + }, status=200, ) @@ -104,7 +107,10 @@ def test_select_results(self): } self.client.stop_query(query_id) - assert self.client.compose_query_web_link(query_id) == f"https://yq.cloud.yandex.ru/folders/{PROJECT}/ide/queries/query1" + assert ( + self.client.compose_query_web_link(query_id) + == f"https://yq.cloud.yandex.ru/folders/{PROJECT}/ide/queries/query1" + ) @responses.activate() def test_select_two_record_sets(self): diff --git a/tests/providers/yandex/yq_client/test_query_results.py b/tests/providers/yandex/yq_client/test_query_results.py index f445cffe0b19d..30becbd015926 100644 --- a/tests/providers/yandex/yq_client/test_query_results.py +++ b/tests/providers/yandex/yq_client/test_query_results.py @@ -17,9 +17,10 @@ from __future__ import annotations from datetime import datetime -from dateutil.tz import tzutc from decimal import Decimal +from dateutil.tz import tzutc + from airflow.providers.yandex.yq_client.query_results import YQResults From f2fc45a029f23124dd6a21029c33333f61eae33c Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sun, 10 Mar 2024 12:41:34 +0300 Subject: [PATCH 26/34] fix CI tests, add yq link tests --- airflow/providers/yandex/hooks/yq.py | 2 +- .../providers/yandex/yq_client/http_client.py | 4 +- .../yandex/yq_client/query_results.py | 2 +- dev/breeze/tests/test_selective_checks.py | 2 +- tests/providers/yandex/hooks/test_yq.py | 3 +- tests/providers/yandex/links/__init__.py | 16 +++++ tests/providers/yandex/links/test_yq.py | 62 +++++++++++++++++++ 7 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 tests/providers/yandex/links/__init__.py create mode 100644 tests/providers/yandex/links/test_yq.py diff --git a/airflow/providers/yandex/hooks/yq.py b/airflow/providers/yandex/hooks/yq.py index 1d0106fda207f..288ee9ae71a67 100644 --- a/airflow/providers/yandex/hooks/yq.py +++ b/airflow/providers/yandex/hooks/yq.py @@ -22,7 +22,7 @@ import jwt import requests -from requests.packages.urllib3.util.retry import Retry +from urllib3.util.retry import Retry from airflow.exceptions import AirflowException from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook diff --git a/airflow/providers/yandex/yq_client/http_client.py b/airflow/providers/yandex/yq_client/http_client.py index e7c57f65c0ec9..a999f4367b67e 100644 --- a/airflow/providers/yandex/yq_client/http_client.py +++ b/airflow/providers/yandex/yq_client/http_client.py @@ -32,6 +32,8 @@ TIME_BETWEEN_RETRIES = 1000 ERROR_CODES = (500, 502, 504) +logger = logging.getLogger(__name__) + def _requests_retry_session( session, retries=MAX_RETRY_FOR_SESSION, back_off_factor=BACK_OFF_FACTOR, status_force_list=ERROR_CODES @@ -124,7 +126,7 @@ def _compose_web_url(self, path: str) -> str: return self.config.web_base_url + path def _validate_http_error(self, response, expected_code=200) -> None: - logging.debug("Response: %s, %s", response.status_code, response.text) + logger.debug("Response: %s, %s", response.status_code, response.text) if response.status_code != expected_code: if response.headers.get("Content-Type", "").startswith("application/json"): body = response.json() diff --git a/airflow/providers/yandex/yq_client/query_results.py b/airflow/providers/yandex/yq_client/query_results.py index 1b4ec75e81d21..a321e835b256f 100644 --- a/airflow/providers/yandex/yq_client/query_results.py +++ b/airflow/providers/yandex/yq_client/query_results.py @@ -155,7 +155,7 @@ def id(v): @staticmethod def _get_converter(column_type: str) -> Any: - """Returns converter based on column type.""" + """Return converter based on column type.""" # primitives if column_type in [ "Int8", diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 68a03f714f6b9..f0810e0090bc5 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1463,7 +1463,7 @@ def test_upgrade_to_newer_dependencies( "docs-list-as-string": "apache-airflow amazon apache.drill apache.druid apache.hive " "apache.impala apache.pinot common.sql databricks elasticsearch " "exasol google jdbc microsoft.mssql mysql odbc openlineage " - "oracle pgvector postgres presto slack snowflake sqlite teradata trino vertica", + "oracle pgvector postgres presto slack snowflake sqlite teradata trino vertica yandex", }, id="Common SQL provider package python files changed", ), diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py index d95b340daf127..98ec29ea98534 100644 --- a/tests/providers/yandex/hooks/test_yq.py +++ b/tests/providers/yandex/hooks/test_yq.py @@ -17,9 +17,10 @@ from __future__ import annotations from datetime import timedelta +from unittest import mock + import responses from responses import matchers -from unittest import mock from airflow.models import Connection from airflow.providers.yandex.hooks.yq import YQHook diff --git a/tests/providers/yandex/links/__init__.py b/tests/providers/yandex/links/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/providers/yandex/links/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/yandex/links/test_yq.py b/tests/providers/yandex/links/test_yq.py new file mode 100644 index 0000000000000..4bed6ba5b56a5 --- /dev/null +++ b/tests/providers/yandex/links/test_yq.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from airflow.models.xcom import XCom +from airflow.models.taskinstance import TaskInstance + +from airflow.providers.yandex.links.yq import YQLink +from tests.test_utils.mock_operators import MockOperator + + +def test_persist(): + mock_context = mock.MagicMock() + + YQLink.persist( + context=mock_context, + task_instance=MockOperator(task_id="test_task_id"), + web_link="g.com" + ) + + ti = mock_context["ti"] + ti.xcom_push.assert_called_once_with( + execution_date=None, + key="web_link", + value="g.com", + ) + + +def test_default_link(): + with mock.patch.object(XCom, "get_value") as m: + m.return_value = None + link = YQLink() + + op = MockOperator(task_id="test_task_id") + ti = TaskInstance(task=op, run_id="run_id1") + assert link.get_link(op, ti_key=ti.key) == "https://yq.cloud.yandex.ru" + + +def test_link(): + with mock.patch.object(XCom, "get_value") as m: + m.return_value = "https://g.com" + link = YQLink() + + op = MockOperator(task_id="test_task_id") + ti = TaskInstance(task=op, run_id="run_id1") + assert link.get_link(op, ti_key=ti.key) == "https://g.com" From 54110acabd10e8841ba2033665076a87860dfec2 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sun, 10 Mar 2024 13:52:49 +0300 Subject: [PATCH 27/34] add doc strings --- airflow/providers/yandex/yq_client/http_client.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/providers/yandex/yq_client/http_client.py b/airflow/providers/yandex/yq_client/http_client.py index a999f4367b67e..7cb06df376344 100644 --- a/airflow/providers/yandex/yq_client/http_client.py +++ b/airflow/providers/yandex/yq_client/http_client.py @@ -91,12 +91,15 @@ def __init__(self, config: YQHttpClientConfig): self.session = _requests_retry_session(session=requests.Session()) def __enter__(self): + """Return the object when a context manager is created.""" return self def __exit__(self, *args): + """Close network connection when exiting the context manager.""" self.session.close() def close(self): + """Close network connection.""" self.session.close() def _build_headers(self, idempotency_key=None, request_id=None) -> dict[str, str]: From d238e206e9053c22c91184135d4ce718375dbad6 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sun, 10 Mar 2024 14:34:16 +0300 Subject: [PATCH 28/34] fix link style tests --- tests/providers/yandex/links/test_yq.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/providers/yandex/links/test_yq.py b/tests/providers/yandex/links/test_yq.py index 4bed6ba5b56a5..82113fa44e076 100644 --- a/tests/providers/yandex/links/test_yq.py +++ b/tests/providers/yandex/links/test_yq.py @@ -18,9 +18,8 @@ from unittest import mock -from airflow.models.xcom import XCom from airflow.models.taskinstance import TaskInstance - +from airflow.models.xcom import XCom from airflow.providers.yandex.links.yq import YQLink from tests.test_utils.mock_operators import MockOperator @@ -28,11 +27,7 @@ def test_persist(): mock_context = mock.MagicMock() - YQLink.persist( - context=mock_context, - task_instance=MockOperator(task_id="test_task_id"), - web_link="g.com" - ) + YQLink.persist(context=mock_context, task_instance=MockOperator(task_id="test_task_id"), web_link="g.com") ti = mock_context["ti"] ti.xcom_push.assert_called_once_with( From ac36e5330bcc02405de53f60f9161423a437bdbb Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 14 Mar 2024 12:54:21 +0300 Subject: [PATCH 29/34] rename files, add deps, fix doc string --- .../yandex/operators/{yandexcloud_yq.py => yq.py} | 9 ++++----- airflow/providers/yandex/provider.yaml | 5 ++++- .../operators/{test_yandexcloud_yq.py => test_yq.py} | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) rename airflow/providers/yandex/operators/{yandexcloud_yq.py => yq.py} (91%) rename tests/providers/yandex/operators/{test_yandexcloud_yq.py => test_yq.py} (97%) diff --git a/airflow/providers/yandex/operators/yandexcloud_yq.py b/airflow/providers/yandex/operators/yq.py similarity index 91% rename from airflow/providers/yandex/operators/yandexcloud_yq.py rename to airflow/providers/yandex/operators/yq.py index a7f5d9ba24d76..e6419efd7172b 100644 --- a/airflow/providers/yandex/operators/yandexcloud_yq.py +++ b/airflow/providers/yandex/operators/yq.py @@ -33,8 +33,7 @@ class YQExecuteQueryOperator(SQLExecuteQueryOperator): :param sql: the SQL code to be executed as a single string :param name: name of the query in YandexQuery :param folder_id: cloud folder id where to create query - :param connection_id: Airflow connection ID to get parameters from - :param folder_id: cloud folder id where to create query + :param yandex_conn_id: Airflow connection ID to get parameters from """ operator_extra_links = (YQLink(),) @@ -48,7 +47,7 @@ def __init__( *, name: str | None = None, folder_id: str | None = None, - connection_id: str | None = None, + yandex_conn_id: str | None = None, public_ssh_key: str | None = None, service_account_id: str | None = None, **kwargs, @@ -56,7 +55,7 @@ def __init__( super().__init__(**kwargs) self.name = name self.folder_id = folder_id - self.connection_id = connection_id + self.yandex_conn_id = yandex_conn_id self.public_ssh_key = public_ssh_key self.service_account_id = service_account_id @@ -65,7 +64,7 @@ def __init__( def execute(self, context: Context) -> Any: self.hook = YQHook( - yandex_conn_id=self.connection_id, + yandex_conn_id=self.yandex_conn_id, default_folder_id=self.folder_id, default_public_ssh_key=self.public_ssh_key, default_service_account_id=self.service_account_id, diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index 12627e56557fd..af4d1d8c54450 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -49,6 +49,9 @@ versions: dependencies: - apache-airflow>=2.6.0 - yandexcloud>=0.228.0 + - python-dateutil>=2.8.0 + # Requests 3 if it will be released, will be heavily breaking. + - requests>=2.27.0,<3 integrations: - integration-name: Yandex.Cloud @@ -77,7 +80,7 @@ operators: - integration-name: Yandex.Cloud YQ python-modules: - - airflow.providers.yandex.operators.yandexcloud_yq + - airflow.providers.yandex.operators.yq hooks: - integration-name: Yandex.Cloud diff --git a/tests/providers/yandex/operators/test_yandexcloud_yq.py b/tests/providers/yandex/operators/test_yq.py similarity index 97% rename from tests/providers/yandex/operators/test_yandexcloud_yq.py rename to tests/providers/yandex/operators/test_yq.py index ee465804f8bb7..040f4089b4c6a 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_yq.py +++ b/tests/providers/yandex/operators/test_yq.py @@ -26,7 +26,7 @@ from airflow.models import Connection from airflow.models.dag import DAG -from airflow.providers.yandex.operators.yandexcloud_yq import YQExecuteQueryOperator +from airflow.providers.yandex.operators.yq import YQExecuteQueryOperator OAUTH_TOKEN = "my_oauth_token" FOLDER_ID = "my_folder_id" From 2c3f3152c654c9fd3c1fc2f66c233a8dc38ec9e1 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 14 Mar 2024 15:50:04 +0300 Subject: [PATCH 30/34] replace SQLExecuteQueryOperator with BaseOperator --- airflow/providers/yandex/operators/yq.py | 26 +++++++++++++---------- dev/breeze/tests/test_selective_checks.py | 2 +- generated/provider_dependencies.json | 6 +++--- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/airflow/providers/yandex/operators/yq.py b/airflow/providers/yandex/operators/yq.py index e6419efd7172b..d6c258305c042 100644 --- a/airflow/providers/yandex/operators/yq.py +++ b/airflow/providers/yandex/operators/yq.py @@ -16,9 +16,10 @@ # under the License. from __future__ import annotations +from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence -from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.models import BaseOperator from airflow.providers.yandex.hooks.yq import YQHook from airflow.providers.yandex.links.yq import YQLink @@ -26,7 +27,7 @@ from airflow.utils.context import Context -class YQExecuteQueryOperator(SQLExecuteQueryOperator): +class YQExecuteQueryOperator(BaseOperator): """ Executes sql code using Yandex Query service. @@ -50,6 +51,7 @@ def __init__( yandex_conn_id: str | None = None, public_ssh_key: str | None = None, service_account_id: str | None = None, + sql: str, **kwargs, ) -> None: super().__init__(**kwargs) @@ -58,18 +60,21 @@ def __init__( self.yandex_conn_id = yandex_conn_id self.public_ssh_key = public_ssh_key self.service_account_id = service_account_id + self.sql = sql - self.hook: YQHook | None = None self.query_id: str | None = None - def execute(self, context: Context) -> Any: - self.hook = YQHook( - yandex_conn_id=self.yandex_conn_id, - default_folder_id=self.folder_id, - default_public_ssh_key=self.public_ssh_key, - default_service_account_id=self.service_account_id, - ) + @cached_property + def hook(self) -> YQHook: + """Get valid hook.""" + return YQHook( + yandex_conn_id=self.yandex_conn_id, + default_folder_id=self.folder_id, + default_public_ssh_key=self.public_ssh_key, + default_service_account_id=self.service_account_id, + ) + def execute(self, context: Context) -> Any: self.query_id = self.hook.create_query(query_text=self.sql, name=self.name) # pass to YQLink @@ -85,4 +90,3 @@ def on_kill(self) -> None: if self.hook is not None and self.query_id is not None: self.hook.stop_query(self.query_id) self.hook.close() - self.hook = None diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index f0810e0090bc5..68a03f714f6b9 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1463,7 +1463,7 @@ def test_upgrade_to_newer_dependencies( "docs-list-as-string": "apache-airflow amazon apache.drill apache.druid apache.hive " "apache.impala apache.pinot common.sql databricks elasticsearch " "exasol google jdbc microsoft.mssql mysql odbc openlineage " - "oracle pgvector postgres presto slack snowflake sqlite teradata trino vertica yandex", + "oracle pgvector postgres presto slack snowflake sqlite teradata trino vertica", }, id="Common SQL provider package python files changed", ), diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index c751aa92ddc3a..6168100e371b2 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1181,12 +1181,12 @@ "yandex": { "deps": [ "apache-airflow>=2.6.0", + "python-dateutil>=2.8.0", + "requests>=2.27.0,<3", "yandexcloud>=0.228.0" ], "devel-deps": [], - "cross-providers-deps": [ - "common.sql" - ], + "cross-providers-deps": [], "excluded-python-versions": [], "state": "ready" }, From c1a760d3489bf98bec8e98fc03c82b730f654e98 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 14 Mar 2024 19:30:08 +0300 Subject: [PATCH 31/34] fix static checks --- airflow/providers/yandex/operators/yq.py | 10 +++++----- pyproject.toml | 2 ++ 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/airflow/providers/yandex/operators/yq.py b/airflow/providers/yandex/operators/yq.py index d6c258305c042..52261edd31893 100644 --- a/airflow/providers/yandex/operators/yq.py +++ b/airflow/providers/yandex/operators/yq.py @@ -68,11 +68,11 @@ def __init__( def hook(self) -> YQHook: """Get valid hook.""" return YQHook( - yandex_conn_id=self.yandex_conn_id, - default_folder_id=self.folder_id, - default_public_ssh_key=self.public_ssh_key, - default_service_account_id=self.service_account_id, - ) + yandex_conn_id=self.yandex_conn_id, + default_folder_id=self.folder_id, + default_public_ssh_key=self.public_ssh_key, + default_service_account_id=self.service_account_id, + ) def execute(self, context: Context) -> Any: self.query_id = self.hook.create_query(query_text=self.sql, name=self.name) diff --git a/pyproject.toml b/pyproject.toml index 2a26414becc74..de6a4733d182a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -969,6 +969,8 @@ weaviate = [ # source: airflow/providers/weaviate/provider.yaml "weaviate-client>=3.24.2", ] yandex = [ # source: airflow/providers/yandex/provider.yaml + "python-dateutil>=2.8.0", + "requests>=2.27.0,<3", "yandexcloud>=0.228.0", ] zendesk = [ # source: airflow/providers/zendesk/provider.yaml From daa96803d53611f5537aee370fed62b354d60ec8 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Thu, 14 Mar 2024 19:47:28 +0300 Subject: [PATCH 32/34] fight with static checks --- airflow/providers/yandex/yq_client/query_results.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/airflow/providers/yandex/yq_client/query_results.py b/airflow/providers/yandex/yq_client/query_results.py index a321e835b256f..b243dbe49477b 100644 --- a/airflow/providers/yandex/yq_client/query_results.py +++ b/airflow/providers/yandex/yq_client/query_results.py @@ -143,7 +143,9 @@ def _convert_from_optional(value: list[Any]) -> Any: # If value is None than result is {"rows":[[[]]]} # So check if len equals 1 it means that it contains value # if len is 0 it means it has no value i.e. value is None - assert len(value) < 2, str(value) + if len(value) > 1: + raise RuntimeError(f"Value should have len 0 or 1, but has len {len(value)}, value: {value}") + if len(value) == 1: return value[0] @@ -228,9 +230,11 @@ def convert(x): inner_converters_list = [YQResults._get_converter(t) for t in inner_types_list] def convert(x): - assert len(x) == len( - inner_converters_list - ), f"Wrong length for tuple value: {len(x)} != {len(inner_converters_list)}" + if len(x) != len(inner_converters_list): + raise RuntimeError( + f"Wrong length for tuple value: {len(x)} != {len(inner_converters_list)}" + ) + return tuple([c(v) for (c, v) in zip(inner_converters_list, x)]) return convert From 072e625e2ef86c09d96c0dc9aa411944d60cf677 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 16 Mar 2024 15:59:55 +0300 Subject: [PATCH 33/34] remove http client, use py package --- airflow/providers/yandex/hooks/yq.py | 2 +- airflow/providers/yandex/provider.yaml | 1 + .../providers/yandex/yq_client/__init__.py | 16 - .../providers/yandex/yq_client/http_client.py | 329 ----------------- .../yandex/yq_client/query_results.py | 336 ------------------ generated/provider_dependencies.json | 1 + pyproject.toml | 1 + tests/providers/yandex/hooks/test_yq.py | 4 +- tests/providers/yandex/yq_client/__init__.py | 16 - .../yandex/yq_client/test_http_client.py | 130 ------- .../yandex/yq_client/test_query_results.py | 278 --------------- 11 files changed, 6 insertions(+), 1108 deletions(-) delete mode 100644 airflow/providers/yandex/yq_client/__init__.py delete mode 100644 airflow/providers/yandex/yq_client/http_client.py delete mode 100644 airflow/providers/yandex/yq_client/query_results.py delete mode 100644 tests/providers/yandex/yq_client/__init__.py delete mode 100644 tests/providers/yandex/yq_client/test_http_client.py delete mode 100644 tests/providers/yandex/yq_client/test_query_results.py diff --git a/airflow/providers/yandex/hooks/yq.py b/airflow/providers/yandex/hooks/yq.py index 288ee9ae71a67..6432c876a90f6 100644 --- a/airflow/providers/yandex/hooks/yq.py +++ b/airflow/providers/yandex/hooks/yq.py @@ -27,7 +27,7 @@ from airflow.exceptions import AirflowException from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.providers.yandex.utils.user_agent import provider_user_agent -from airflow.providers.yandex.yq_client.http_client import YQHttpClient, YQHttpClientConfig +from yandex_query_client import YQHttpClient, YQHttpClientConfig class YQHook(YandexCloudBaseHook): diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml index af4d1d8c54450..0135ac3fb4835 100644 --- a/airflow/providers/yandex/provider.yaml +++ b/airflow/providers/yandex/provider.yaml @@ -49,6 +49,7 @@ versions: dependencies: - apache-airflow>=2.6.0 - yandexcloud>=0.228.0 + - yandex-query-client>=0.1.2 - python-dateutil>=2.8.0 # Requests 3 if it will be released, will be heavily breaking. - requests>=2.27.0,<3 diff --git a/airflow/providers/yandex/yq_client/__init__.py b/airflow/providers/yandex/yq_client/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/airflow/providers/yandex/yq_client/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/airflow/providers/yandex/yq_client/http_client.py b/airflow/providers/yandex/yq_client/http_client.py deleted file mode 100644 index 7cb06df376344..0000000000000 --- a/airflow/providers/yandex/yq_client/http_client.py +++ /dev/null @@ -1,329 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import logging -import time -from datetime import datetime -from typing import Any - -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -from airflow.providers.yandex.yq_client.query_results import YQResults - -MAX_RETRY_FOR_SESSION = 4 -BACK_OFF_FACTOR = 0.3 -TIME_BETWEEN_RETRIES = 1000 -ERROR_CODES = (500, 502, 504) - -logger = logging.getLogger(__name__) - - -def _requests_retry_session( - session, retries=MAX_RETRY_FOR_SESSION, back_off_factor=BACK_OFF_FACTOR, status_force_list=ERROR_CODES -): - retry = Retry( - total=retries, - read=retries, - connect=retries, - backoff_factor=back_off_factor, - status_forcelist=status_force_list, - allowed_methods=frozenset(["GET", "POST"]), - ) - adapter = HTTPAdapter(max_retries=retry) - session.mount("http://", adapter) - session.mount("https://", adapter) - return session - - -class YQHttpClientConfig: - """YandexQuery HTTP client config.""" - - def __init__( - self, - token: str | None = None, - project: str | None = None, - user_agent: str | None = "Python YQ HTTP SDK", - ) -> None: - self.token = token - self.project = project - self.user_agent = user_agent - - # urls should not contain trailing / - self.endpoint: str = "https://api.yandex-query.cloud.yandex.net" - self.web_base_url: str = "https://yq.cloud.yandex.ru" - self.token_prefix = "Bearer " - - -class YQHttpClientException(Exception): - """YandexQuery client exception type.""" - - def __init__( - self, message: str, status: str | None = None, msg: str | None = None, details: Any = None - ) -> None: - super().__init__(message) - self.status = status - self.msg = msg - self.details = details - - -class YQHttpClient: - """YandexQuery HTTP client.""" - - def __init__(self, config: YQHttpClientConfig): - self.config = config - self.session = _requests_retry_session(session=requests.Session()) - - def __enter__(self): - """Return the object when a context manager is created.""" - return self - - def __exit__(self, *args): - """Close network connection when exiting the context manager.""" - self.session.close() - - def close(self): - """Close network connection.""" - self.session.close() - - def _build_headers(self, idempotency_key=None, request_id=None) -> dict[str, str]: - headers = {"Authorization": f"{self.config.token_prefix}{self.config.token}"} - if idempotency_key is not None: - headers["Idempotency-Key"] = idempotency_key - - if request_id is not None: - headers["x-request-id"] = request_id - - if self.config.user_agent is not None: - headers["User-Agent"] = self.config.user_agent - - return headers - - def _build_params(self) -> dict[str, str]: - params = {} - if self.config.project is not None: - params["project"] = self.config.project - - return params - - def _compose_api_url(self, path: str) -> str: - return self.config.endpoint + path - - def _compose_web_url(self, path: str) -> str: - return self.config.web_base_url + path - - def _validate_http_error(self, response, expected_code=200) -> None: - logger.debug("Response: %s, %s", response.status_code, response.text) - if response.status_code != expected_code: - if response.headers.get("Content-Type", "").startswith("application/json"): - body = response.json() - status = body.get("status") - msg = body.get("message") - details = body.get("details") - raise YQHttpClientException( - f"Error occurred. http code={response.status_code}, status={status}, msg={msg}, details={details}", - status=status, - msg=msg, - details=details, - ) - - raise YQHttpClientException(f"Error occurred: {response.status_code}, {response.text}") - - def create_query( - self, - query_text=None, - query_type=None, - name=None, - description=None, - idempotency_key=None, - request_id=None, - expected_code=200, - ): - body = {} - if query_text is not None: - body["text"] = query_text - - if query_type is not None: - body["type"] = query_type - - if name is not None: - body["name"] = name - - if description is not None: - body["description"] = description - - response = self.session.post( - self._compose_api_url("/api/fq/v1/queries"), - headers=self._build_headers(idempotency_key=idempotency_key, request_id=request_id), - params=self._build_params(), - json=body, - ) - - self._validate_http_error(response, expected_code=expected_code) - return response.json()["id"] - - def get_query_status(self, query_id, request_id=None, expected_code=200) -> Any: - response = self.session.get( - self._compose_api_url(f"/api/fq/v1/queries/{query_id}/status"), - headers=self._build_headers(request_id=request_id), - params=self._build_params(), - ) - - self._validate_http_error(response, expected_code=expected_code) - return response.json()["status"] - - def get_query(self, query_id, request_id=None, expected_code=200) -> Any: - response = self.session.get( - self._compose_api_url(f"/api/fq/v1/queries/{query_id}"), - headers=self._build_headers(request_id=request_id), - params=self._build_params(), - ) - - self._validate_http_error(response, expected_code=expected_code) - return response.json() - - def stop_query( - self, - query_id: str, - idempotency_key: str | None = None, - request_id: str | None = None, - expected_code: int = 204, - ) -> Any: - headers = self._build_headers(idempotency_key=idempotency_key, request_id=request_id) - response = self.session.post( - self._compose_api_url(f"/api/fq/v1/queries/{query_id}/stop"), - headers=headers, - params=self._build_params(), - ) - self._validate_http_error(response, expected_code=expected_code) - return response - - def wait_query_to_complete(self, query_id, execution_timeout=None, stop_on_timeout=False) -> str: - status = None - delay = 0.2 # start with 0.2 sec - try: - start = datetime.now() - while True: - if execution_timeout is not None and datetime.now() > start + execution_timeout: - raise TimeoutError(f"Query {query_id} execution timeout, last status {status}") - - status = self.get_query_status(query_id) - if status not in ["RUNNING", "PENDING"]: - return status - - time.sleep(delay) - delay *= 2 - delay = min(2, delay) # up to 2 seconds - - except TimeoutError: - if stop_on_timeout: - self.stop_query(query_id) - raise - - def wait_query_to_succeed(self, query_id, execution_timeout=None, stop_on_timeout=False) -> int: - status = self.wait_query_to_complete( - query_id=query_id, execution_timeout=execution_timeout, stop_on_timeout=stop_on_timeout - ) - - query = self.get_query(query_id) - if status != "COMPLETED": - issues = query["issues"] - raise RuntimeError(f"Query {query_id} failed with issues={issues}") - - return len(query["result_sets"]) - - def get_query_result_set_page( - self, - query_id, - result_set_index, - offset=None, - limit=None, - raw_format=False, - request_id=None, - expected_code=200, - ) -> Any: - params = self._build_params() - if offset is not None: - params["offset"] = offset - - if limit is not None: - params["limit"] = limit - - response = self.session.get( - self._compose_api_url(f"/api/fq/v1/queries/{query_id}/results/{result_set_index}"), - headers=self._build_headers(request_id=request_id), - params=params, - ) - - self._validate_http_error(response, expected_code=expected_code) - return response.json() - - def get_query_result_set(self, query_id: str, result_set_index: int, raw_format: bool = False) -> Any: - offset = 0 - limit = 1000 - columns = None - rows = [] - while True: - part = self.get_query_result_set_page( - query_id, result_set_index=result_set_index, offset=offset, limit=limit, raw_format=raw_format - ) - - if columns is None: - columns = part["columns"] - - r = part["rows"] - rows.extend(r) - if len(r) != limit: - break - - offset += limit - - result = {"rows": rows, "columns": columns} - if raw_format: - return result - - return YQResults(result).results - - def get_query_all_result_sets( - self, query_id: str, result_set_count: int, raw_format: bool = False - ) -> Any: - result = [] - for i in range(0, result_set_count): - r = self.get_query_result_set(query_id, result_set_index=i, raw_format=raw_format) - - if result_set_count == 1: - return r - - result.append(r) - return result - - def get_openapi_spec(self) -> str: - response = self.session.get(self._compose_api_url("/resources/v1/openapi.yaml")) - self._validate_http_error(response) - return response.text - - def compose_query_web_link(self, query_id) -> str: - return self._compose_web_url(f"/folders/{self.config.project}/ide/queries/{query_id}") - - @staticmethod - def result_set_to_dataframe(data): - import pandas as pd - - column_names = [column["name"] for column in data["columns"]] - return pd.DataFrame(data["rows"], columns=column_names) diff --git a/airflow/providers/yandex/yq_client/query_results.py b/airflow/providers/yandex/yq_client/query_results.py deleted file mode 100644 index b243dbe49477b..0000000000000 --- a/airflow/providers/yandex/yq_client/query_results.py +++ /dev/null @@ -1,336 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import base64 -import pprint -from datetime import datetime -from decimal import Decimal -from typing import Any - -import dateutil.parser - - -class YQResults: - """Holds and formats query execution results.""" - - def __init__(self, results: dict[str, Any]): - self._raw_results = results - self._results = None - - @staticmethod - def _convert_from_float(value: float | str) -> float: - # special values, e.g inf encoded as str, normal values are in float - return float(value) - - @staticmethod - def _convert_from_pgfloat(value: str | None) -> float | None: - if value is None: - return None - return float(value) - - @staticmethod - def _convert_from_pgint(value: str | None) -> int | None: - if value is None: - return None - return int(value) - - @staticmethod - def _convert_from_decimal(value: str) -> Decimal: - return Decimal(value) - - @staticmethod - def _convert_from_pgnumeric(value: str | None) -> Decimal | None: - if value is None: - return None - return Decimal(value) - - @staticmethod - def _convert_from_base64(value: str) -> str | bytes: - b = base64.b64decode(value) - try: - return b.decode("utf-8") - except UnicodeDecodeError: - return b - - @staticmethod - def _convert_from_datetime(value: str) -> datetime: - # suitable for yql data and datetime parsing - return dateutil.parser.isoparse(value) - - @staticmethod - def _convert_from_pgdatetime(value: str | None) -> datetime | None: - if value is None: - return None - return dateutil.parser.isoparse(value) - - @staticmethod - def _convert_from_enum(value: list) -> str: - return str(value[0]) - - @staticmethod - def _extract_from_optional(type_name: str) -> str: - # Uint16? -> Uint16 - if type_name.endswith("?"): - return type_name[0:-1] - - # Optional -> Uint16 - return type_name[len("Optional<") : -1] - - @staticmethod - def _extract_from_set(type_name: str) -> str: - # Set -> Uint16 - return type_name[len("Set<") : -1] - - @staticmethod - def _extract_from_list(type_name: str) -> str: - # List -> Uint16 - return type_name[len("List<") : -1] - - @staticmethod - def _split_type_list(type_list: str) -> list[str]: - # naive implementation - # fixme: fix it - return type_list.split(",") - - @staticmethod - def _extract_from_tuple(type_name: str) -> list[str]: - # Tuple -> [Uint16, String, Double] - return YQResults._split_type_list(type_name[len("Tuple<") : -1]) - - @staticmethod - def _extract_from_dict(type_name: str) -> tuple[str, str]: - # Dict -> (Uint16, String) - [key_type, value_type] = YQResults._split_type_list(type_name[len("Dict<") : -1]) - return key_type, value_type - - @staticmethod - def _extract_from_variant_over_struct(type_name: str) -> dict[str, str]: - # Variant<'One':Int32,'Two':String> -> {One: Int32, Two: String} - types_with_names = YQResults._split_type_list(type_name[len("Variant<") : -1]) - result = {} - for t in types_with_names: - [n, t] = t.split(":") - # strip ' - n = n[1:-1] - result[n] = t - return result - - @staticmethod - def _extract_from_variant_over_tuple(type_name: str) -> list[str]: - # Variant -> [Int32, String] - return YQResults._split_type_list(type_name[len("Variant<") : -1]) - - @staticmethod - def _convert_from_optional(value: list[Any]) -> Any: - # Optional types are encoded as [[]] objects - # If type is Uint16, value is encoded as {"rows":[[value]]} - # If type is Optional, value is encoded as {"rows":[[[value]]]} - # If value is None than result is {"rows":[[[]]]} - # So check if len equals 1 it means that it contains value - # if len is 0 it means it has no value i.e. value is None - if len(value) > 1: - raise RuntimeError(f"Value should have len 0 or 1, but has len {len(value)}, value: {value}") - - if len(value) == 1: - return value[0] - - return None - - @staticmethod - def id(v): - return v - - @staticmethod - def _get_converter(column_type: str) -> Any: - """Return converter based on column type.""" - # primitives - if column_type in [ - "Int8", - "Int16", - "Int32", - "Int64", - "Uint8", - "Uint16", - "Uint32", - "Uint64", - "Bool", - "Utf8", - "Uuid", - "Void", - "Null", - "EmptyList", - "Struct<>", - "Tuple<>", - ]: - return YQResults.id - - if column_type == "String": - return YQResults._convert_from_base64 - - if column_type in ["Float", "Double"]: - return YQResults._convert_from_float - - if column_type.startswith("Decimal("): - return YQResults._convert_from_decimal - - if column_type.startswith("Enum<"): - return YQResults._convert_from_enum - - if column_type in ["Date", "Datetime", "Timestamp"]: - return YQResults._convert_from_datetime - - # containers - if column_type.startswith("Optional<") or column_type.endswith("?"): - # If type is Optional than get base type - inner_converter = YQResults._get_converter(YQResults._extract_from_optional(column_type)) - - # Remove "Optional" encoding - # and convert resulting value as others - def convert(x): - inner_value = YQResults._convert_from_optional(x) - if inner_value is None: - return None - return inner_converter(inner_value) - - return convert - - if column_type.startswith("Set<"): - inner_converter = YQResults._get_converter(YQResults._extract_from_set(column_type)) - - def convert(x): - return {inner_converter(v) for v in x} - - return convert - - if column_type.startswith("List<"): - inner_converter = YQResults._get_converter(YQResults._extract_from_list(column_type)) - - def convert(x): - return [inner_converter(v) for v in x] - - return convert - - if column_type.startswith("Tuple<"): - inner_types_list = YQResults._extract_from_tuple(column_type) - inner_converters_list = [YQResults._get_converter(t) for t in inner_types_list] - - def convert(x): - if len(x) != len(inner_converters_list): - raise RuntimeError( - f"Wrong length for tuple value: {len(x)} != {len(inner_converters_list)}" - ) - - return tuple([c(v) for (c, v) in zip(inner_converters_list, x)]) - - return convert - - # variant over struct - if column_type.startswith("Variant<'"): - inner_types_dict = YQResults._extract_from_variant_over_struct(column_type) - inner_converters_dict = {k: YQResults._get_converter(t) for k, t in inner_types_dict.items()} - - def convert(x): - return inner_converters_dict[x[0]](x[1]) - - return convert - - # variant over tuple - if column_type.startswith("Variant<"): - inner_types_list = YQResults._extract_from_variant_over_tuple(column_type) - inner_converters_list = [YQResults._get_converter(t) for t in inner_types_list] - - def convert(x): - return inner_converters_list[x[0]](x[1]) - - return convert - - if column_type == "EmptyDict": - - def convert(x): - return {} - - return convert - - if column_type.startswith("Dict<"): - key_type, value_type = YQResults._extract_from_dict(column_type) - key_converter = YQResults._get_converter(key_type) - value_converter = YQResults._get_converter(value_type) - - def convert(x): - return {key_converter(v[0]): value_converter(v[1]) for v in x} - - return convert - - # pg types - if column_type.startswith("pgfloat"): - return YQResults._convert_from_pgfloat - - if column_type in ["pgint2", "pgint4", "pgint8"]: - return YQResults._convert_from_pgint - - if column_type == "pgnumeric": - return YQResults._convert_from_pgnumeric - - if column_type in ["pgdate", "pgtimestamp"]: - return YQResults._convert_from_pgdatetime - - if column_type.startswith("pg"): - return YQResults.id - - # unsupported type - return YQResults.id - - def _convert(self): - converters = [] - converted_results = [] - for column in self._raw_results["columns"]: - converters.append(YQResults._get_converter(column["type"])) - - for row in self._raw_results["rows"]: - new_row = [] - for index, value in enumerate(row): - converter = converters[index] - new_row.append(value if converter is None else converter(value)) - - converted_results.append(new_row) - - self._results = {"rows": converted_results, "columns": self._raw_results["columns"]} - - def _repr_pretty_(self, p, cycle): - p.text(pprint.pformat(self._results)) - - @property - def results(self): - if self._results is None: - self._convert() - - return self._results - - @property - def raw_results(self): - return self._raw_results - - def to_table(self): - return self._results["rows"] - - def to_dataframe(self): - result_set = self._results - columns = [column["name"] for column in result_set["columns"]] - import pandas - - return pandas.DataFrame(result_set["rows"], columns=columns) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 6168100e371b2..26fe5a1aee343 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1183,6 +1183,7 @@ "apache-airflow>=2.6.0", "python-dateutil>=2.8.0", "requests>=2.27.0,<3", + "yandex-query-client>=0.1.2", "yandexcloud>=0.228.0" ], "devel-deps": [], diff --git a/pyproject.toml b/pyproject.toml index de6a4733d182a..1f56369c01526 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -971,6 +971,7 @@ weaviate = [ # source: airflow/providers/weaviate/provider.yaml yandex = [ # source: airflow/providers/yandex/provider.yaml "python-dateutil>=2.8.0", "requests>=2.27.0,<3", + "yandex-query-client>=0.1.2", "yandexcloud>=0.228.0", ] zendesk = [ # source: airflow/providers/zendesk/provider.yaml diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py index 98ec29ea98534..9ac0cb201eb4b 100644 --- a/tests/providers/yandex/hooks/test_yq.py +++ b/tests/providers/yandex/hooks/test_yq.py @@ -63,7 +63,7 @@ def test_oauth_token_usage(self): assert query_id == "query1" with mock.patch( - "airflow.providers.yandex.yq_client.http_client.YQHttpClient.compose_query_web_link" + "yandex_query_client.YQHttpClient.compose_query_web_link" ) as m: m.return_value = "http://gg.zz" assert self.hook.compose_query_web_link("query1") == "http://gg.zz" @@ -83,7 +83,7 @@ def test_select_results(self, mock_jwt, mock_sdk): mock_sdk.return_value = DummySDK() with mock.patch.multiple( - "airflow.providers.yandex.yq_client.http_client.YQHttpClient", + "yandex_query_client.YQHttpClient", create_query=mock.DEFAULT, wait_query_to_succeed=mock.DEFAULT, get_query_all_result_sets=mock.DEFAULT, diff --git a/tests/providers/yandex/yq_client/__init__.py b/tests/providers/yandex/yq_client/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/tests/providers/yandex/yq_client/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/providers/yandex/yq_client/test_http_client.py b/tests/providers/yandex/yq_client/test_http_client.py deleted file mode 100644 index 2ed5ea6fec79a..0000000000000 --- a/tests/providers/yandex/yq_client/test_http_client.py +++ /dev/null @@ -1,130 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import responses -from responses import matchers - -from airflow.providers.yandex.yq_client.http_client import YQHttpClient, YQHttpClientConfig - -IAM_TOKEN = "my_iam_token" -PROJECT = "my_project" - - -class TestYQHttpClient: - def setup_method(self): - config = YQHttpClientConfig(IAM_TOKEN, PROJECT) - self.client = YQHttpClient(config) - - def setup_mocks_for_query_execution(self, query_results_json_list): - responses.post( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries", - match=[ - matchers.header_matcher( - {"Content-Type": "application/json", "Authorization": f"Bearer {IAM_TOKEN}"} - ), - matchers.query_param_matcher({"project": PROJECT}), - matchers.json_params_matcher({"name": "my query", "text": "select 777"}), - ], - json={"id": "query1"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", - json={"status": "RUNNING"}, - status=200, - ) - - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status", - json={"status": "COMPLETED"}, - status=200, - ) - - result_set_count = len(query_results_json_list) - responses.get( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1", - json={ - "id": "query1", - "result_sets": [{"rows_count": 1, "truncated": False} for _ in range(result_set_count)], - }, - status=200, - ) - - for i in range(result_set_count): - responses.get( - f"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/results/{i}", - json=query_results_json_list[i], - status=200, - ) - - def _create_test_query(self): - query_id = self.client.create_query(query_text="select 777", name="my query") - assert query_id == "query1" - return query_id - - @responses.activate() - def test_select_results(self): - self.setup_mocks_for_query_execution( - [{"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}] - ) - - query_id = self._create_test_query() - - result_set_count = self.client.wait_query_to_succeed(query_id) - assert result_set_count == 1 - results = self.client.get_query_all_result_sets(query_id, result_set_count=result_set_count) - assert results == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - - responses.post( - "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/stop", - match=[ - matchers.header_matcher({"Authorization": f"Bearer {IAM_TOKEN}"}), - matchers.query_param_matcher({"project": PROJECT}), - ], - status=204, - ) - - assert self.client.get_query_status(query_id) == "COMPLETED" - assert self.client.get_query(query_id) == { - "id": "query1", - "result_sets": [{"rows_count": 1, "truncated": False}], - } - self.client.stop_query(query_id) - - assert ( - self.client.compose_query_web_link(query_id) - == f"https://yq.cloud.yandex.ru/folders/{PROJECT}/ide/queries/query1" - ) - - @responses.activate() - def test_select_two_record_sets(self): - self.setup_mocks_for_query_execution( - [ - {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]}, - {"rows": [["zzz"]], "columns": [{"name": "aaaa", "type": "Utf8"}]}, - ] - ) - - query_id = self._create_test_query() - - result_set_count = self.client.wait_query_to_succeed(query_id) - assert result_set_count == 2 - results = self.client.get_query_all_result_sets(query_id, result_set_count=result_set_count) - assert results[0] == {"rows": [[777]], "columns": [{"name": "column0", "type": "Int32"}]} - assert results[1] == {"rows": [["zzz"]], "columns": [{"name": "aaaa", "type": "Utf8"}]} diff --git a/tests/providers/yandex/yq_client/test_query_results.py b/tests/providers/yandex/yq_client/test_query_results.py deleted file mode 100644 index 30becbd015926..0000000000000 --- a/tests/providers/yandex/yq_client/test_query_results.py +++ /dev/null @@ -1,278 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from datetime import datetime -from decimal import Decimal - -from dateutil.tz import tzutc - -from airflow.providers.yandex.yq_client.query_results import YQResults - - -class TestYQResults: - def test_integral_results(self): - # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L336 - r = YQResults( - { - "rows": [ - [ - 100, - -100, - 200, - 200, - 10000000000, - -20000000000, - "18014398509481984", - "-18014398509481984", - 123.5, - -789.125, - "inf", - True, - False, - "aGVsbG8=", - "hello", - "1.23", - 'he"llo_again', - "Я Привет", - 1, - 2, - 3, - 4, - ] - ], - "columns": [ - {"name": "column0", "type": "Int32"}, - {"name": "column1", "type": "Int32"}, - {"name": "column2", "type": "Int64"}, - {"name": "column3", "type": "Uint64"}, - {"name": "column4", "type": "Uint64"}, - {"name": "column5", "type": "Int64"}, - {"name": "column6", "type": "Int64"}, - {"name": "column7", "type": "Int64"}, - {"name": "column8", "type": "Float"}, - {"name": "column9", "type": "Double"}, - {"name": "column10", "type": "Double"}, - {"name": "column11", "type": "Bool"}, - {"name": "column12", "type": "Bool"}, - {"name": "column13", "type": "String"}, - {"name": "column14", "type": "Utf8"}, - {"name": "column15", "type": "Decimal(6,3)"}, - {"name": "column16", "type": "Utf8"}, - {"name": "column17", "type": "Utf8"}, - {"name": "column18", "type": "Int8"}, - {"name": "column19", "type": "Int16"}, - {"name": "column20", "type": "Uint8"}, - {"name": "column21", "type": "Uint16"}, - ], - }, - ) - - results = r.results - - assert results == { - "rows": [ - [ - 100, - -100, - 200, - 200, - 10000000000, - -20000000000, - "18014398509481984", - "-18014398509481984", - 123.5, - -789.125, - float("inf"), - True, - False, - "hello", - "hello", - Decimal("1.23"), - 'he"llo_again', - "Я Привет", - 1, - 2, - 3, - 4, - ] - ], - "columns": [ - {"name": "column0", "type": "Int32"}, - {"name": "column1", "type": "Int32"}, - {"name": "column2", "type": "Int64"}, - {"name": "column3", "type": "Uint64"}, - {"name": "column4", "type": "Uint64"}, - {"name": "column5", "type": "Int64"}, - {"name": "column6", "type": "Int64"}, - {"name": "column7", "type": "Int64"}, - {"name": "column8", "type": "Float"}, - {"name": "column9", "type": "Double"}, - {"name": "column10", "type": "Double"}, - {"name": "column11", "type": "Bool"}, - {"name": "column12", "type": "Bool"}, - {"name": "column13", "type": "String"}, - {"name": "column14", "type": "Utf8"}, - {"name": "column15", "type": "Decimal(6,3)"}, - {"name": "column16", "type": "Utf8"}, - {"name": "column17", "type": "Utf8"}, - {"name": "column18", "type": "Int8"}, - {"name": "column19", "type": "Int16"}, - {"name": "column20", "type": "Uint8"}, - {"name": "column21", "type": "Uint16"}, - ], - } - - def test_complex_results(self): - # json response and results could be found here: https://github.com/ydb-platform/ydb/blob/284b7efb67edcdade0b12c849b7fad40739ad62b/ydb/tests/fq/http_api/test_http_api.py#L445 - r = YQResults( - { - "rows": [ - [ - [], - [1, 2], - [], - [["YWJj", 1]], - [["xyz", 1]], - None, - "PT15M", - "2019-09-16", - "2019-09-16T10:46:05Z", - "2019-09-16T11:27:44.345849Z", - "2019-09-16,Europe/Moscow", - "2019-09-16T14:32:40,Europe/Moscow", - "2019-09-16T14:32:55.874913,Europe/Moscow", - ["One", 12], - [1, "eHl6"], - ["a", 1], - ["monday", None], - 1, - {}, - {"a": 1, "b": "xyz"}, - None, - None, - [[[1, [[177]]]]], - [[[1, []]]], - [[[1, []]]], - ["Foo", None], - ["Bar", None], - [], - [1, "cHJpdmV0", "2019-09-16"], - ] - ], - "columns": [ - {"name": "column0", "type": "EmptyList"}, - {"name": "column1", "type": "List"}, - {"name": "column2", "type": "EmptyDict"}, - {"name": "column3", "type": "Dict"}, - {"name": "column4", "type": "Dict"}, - {"name": "column5", "type": "Uuid"}, - {"name": "column6", "type": "Interval"}, - {"name": "column7", "type": "Date"}, - {"name": "column8", "type": "Datetime"}, - {"name": "column9", "type": "Timestamp"}, - {"name": "column10", "type": "TzDate"}, - {"name": "column11", "type": "TzDatetime"}, - {"name": "column12", "type": "TzTimestamp"}, - {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, - {"name": "column14", "type": "Variant"}, - {"name": "column15", "type": "Variant<'a':Int32>"}, - {"name": "column16", "type": "Enum<'monday'>"}, - {"name": "column17", "type": "Tagged"}, - {"name": "column18", "type": "Struct<>"}, - {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, - {"name": "column20", "type": "Void"}, - {"name": "column21", "type": "Null"}, - {"name": "column22", "type": "Optional?>"}, - {"name": "column23", "type": "Optional?>"}, - {"name": "column24", "type": "Optional?>"}, - {"name": "column25", "type": "Enum<'Bar','Foo'>"}, - {"name": "column26", "type": "Enum<'Bar','Foo'>"}, - {"name": "column27", "type": "Tuple<>"}, - {"name": "column28", "type": "Tuple"}, - ], - }, - ) - - results = r.results - - assert results == { - "rows": [ - [ - [], - [1, 2], - {}, - {"abc": 1}, - {"xyz": 1}, - None, # seems like http api doesn't support uuid values - "PT15M", - datetime(2019, 9, 16, 0, 0), - datetime(2019, 9, 16, 10, 46, 5, tzinfo=tzutc()), - datetime(2019, 9, 16, 11, 27, 44, 345849, tzinfo=tzutc()), - "2019-09-16,Europe/Moscow", - "2019-09-16T14:32:40,Europe/Moscow", - "2019-09-16T14:32:55.874913,Europe/Moscow", - 12, - "xyz", - 1, - "monday", - 1, - {}, - {"a": 1, "b": "xyz"}, - None, - None, - 177, - None, - None, - "Foo", - "Bar", - [], - (1, "privet", datetime(2019, 9, 16, 0, 0)), - ] - ], - "columns": [ - {"name": "column0", "type": "EmptyList"}, - {"name": "column1", "type": "List"}, - {"name": "column2", "type": "EmptyDict"}, - {"name": "column3", "type": "Dict"}, - {"name": "column4", "type": "Dict"}, - {"name": "column5", "type": "Uuid"}, - {"name": "column6", "type": "Interval"}, - {"name": "column7", "type": "Date"}, - {"name": "column8", "type": "Datetime"}, - {"name": "column9", "type": "Timestamp"}, - {"name": "column10", "type": "TzDate"}, - {"name": "column11", "type": "TzDatetime"}, - {"name": "column12", "type": "TzTimestamp"}, - {"name": "column13", "type": "Variant<'One':Int32,'Two':String>"}, - {"name": "column14", "type": "Variant"}, - {"name": "column15", "type": "Variant<'a':Int32>"}, - {"name": "column16", "type": "Enum<'monday'>"}, - {"name": "column17", "type": "Tagged"}, - {"name": "column18", "type": "Struct<>"}, - {"name": "column19", "type": "Struct<'a':Int32,'b':Utf8>"}, - {"name": "column20", "type": "Void"}, - {"name": "column21", "type": "Null"}, - {"name": "column22", "type": "Optional?>"}, - {"name": "column23", "type": "Optional?>"}, - {"name": "column24", "type": "Optional?>"}, - {"name": "column25", "type": "Enum<'Bar','Foo'>"}, - {"name": "column26", "type": "Enum<'Bar','Foo'>"}, - {"name": "column27", "type": "Tuple<>"}, - {"name": "column28", "type": "Tuple"}, - ], - } From da104590d459608c97792bc241e712719c094414 Mon Sep 17 00:00:00 2001 From: Sergey Uzhakov Date: Sat, 16 Mar 2024 18:16:01 +0300 Subject: [PATCH 34/34] fix static checks --- airflow/providers/yandex/hooks/yq.py | 2 +- tests/providers/yandex/hooks/test_yq.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/airflow/providers/yandex/hooks/yq.py b/airflow/providers/yandex/hooks/yq.py index 6432c876a90f6..963709d89b66c 100644 --- a/airflow/providers/yandex/hooks/yq.py +++ b/airflow/providers/yandex/hooks/yq.py @@ -23,11 +23,11 @@ import jwt import requests from urllib3.util.retry import Retry +from yandex_query_client import YQHttpClient, YQHttpClientConfig from airflow.exceptions import AirflowException from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook from airflow.providers.yandex.utils.user_agent import provider_user_agent -from yandex_query_client import YQHttpClient, YQHttpClientConfig class YQHook(YandexCloudBaseHook): diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py index 9ac0cb201eb4b..3b3db91dd1eab 100644 --- a/tests/providers/yandex/hooks/test_yq.py +++ b/tests/providers/yandex/hooks/test_yq.py @@ -62,9 +62,7 @@ def test_oauth_token_usage(self): query_id = self.hook.create_query(query_text="select 777", name="my query") assert query_id == "query1" - with mock.patch( - "yandex_query_client.YQHttpClient.compose_query_web_link" - ) as m: + with mock.patch("yandex_query_client.YQHttpClient.compose_query_web_link") as m: m.return_value = "http://gg.zz" assert self.hook.compose_query_web_link("query1") == "http://gg.zz" m.assert_called_once_with("query1")