From 8ced563f26e93b3a6814361f6296615c50aa5e18 Mon Sep 17 00:00:00 2001 From: Satish Chinthanippu Date: Tue, 2 Jul 2024 13:49:20 -0700 Subject: [PATCH] Added support of Teradata Compute Cluster Provision, Decommission, Suspend and Resume operations (#40509) Support added to Teradata Provider about Teradata Compute Cluster feature Provisioning compute cluster instance Decommissioning compute cluster instance Resume and Suspend of compute cluster instance --- .../operators/teradata_compute_cluster.py | 513 +++++++++++++ airflow/providers/teradata/provider.yaml | 7 + .../providers/teradata/triggers/__init__.py | 16 + .../triggers/teradata_compute_cluster.py | 155 ++++ airflow/providers/teradata/utils/__init__.py | 16 + airflow/providers/teradata/utils/constants.py | 46 ++ .../operators/compute_cluster.rst | 107 +++ .../test_teradata_compute_cluster.py | 713 ++++++++++++++++++ tests/providers/teradata/triggers/__init__.py | 17 + .../triggers/test_teradata_compute_cluster.py | 174 +++++ tests/providers/teradata/utils/__init__.py | 17 + .../teradata/utils/test_constants.py | 110 +++ .../example_teradata_compute_cluster.py | 158 ++++ 13 files changed, 2049 insertions(+) create mode 100644 airflow/providers/teradata/operators/teradata_compute_cluster.py create mode 100644 airflow/providers/teradata/triggers/__init__.py create mode 100644 airflow/providers/teradata/triggers/teradata_compute_cluster.py create mode 100644 airflow/providers/teradata/utils/__init__.py create mode 100644 airflow/providers/teradata/utils/constants.py create mode 100644 docs/apache-airflow-providers-teradata/operators/compute_cluster.rst create mode 100644 tests/providers/teradata/operators/test_teradata_compute_cluster.py create mode 100644 tests/providers/teradata/triggers/__init__.py create mode 100644 tests/providers/teradata/triggers/test_teradata_compute_cluster.py create mode 100644 tests/providers/teradata/utils/__init__.py create mode 100644 tests/providers/teradata/utils/test_constants.py 
create mode 100644 tests/system/providers/teradata/example_teradata_compute_cluster.py diff --git a/airflow/providers/teradata/operators/teradata_compute_cluster.py b/airflow/providers/teradata/operators/teradata_compute_cluster.py new file mode 100644 index 0000000000000..6759c04ce74c1 --- /dev/null +++ b/airflow/providers/teradata/operators/teradata_compute_cluster.py @@ -0,0 +1,513 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import re +from abc import abstractmethod +from enum import Enum +from functools import cached_property +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.teradata.hooks.teradata import TeradataHook +from airflow.providers.teradata.utils.constants import Constants + +if TYPE_CHECKING: + from airflow.utils.context import Context + +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Sequence, cast + +from airflow.providers.teradata.triggers.teradata_compute_cluster import TeradataComputeClusterSyncTrigger + +if TYPE_CHECKING: + from airflow.utils.context import Context + +from airflow.exceptions import AirflowException + + +# Represents +# 1. 
Compute Cluster Setup - Provision and Decommission operations +# 2. Compute Cluster State - Resume and Suspend operations +class _Operation(Enum): + SETUP = 1 + STATE = 2 + + +# Handler to handle single result set of a SQL query +def _single_result_row_handler(cursor): + records = cursor.fetchone() + if isinstance(records, list): + return records[0] + if records is None: + return records + raise TypeError(f"Unexpected results: {cursor.fetchone()!r}") + + +# Determines whether the given operation is a setup or a state operation +def _determine_operation_context(operation): + if operation == Constants.CC_CREATE_OPR or operation == Constants.CC_DROP_OPR: + return _Operation.SETUP + return _Operation.STATE + + +class _TeradataComputeClusterOperator(BaseOperator): + """ + Teradata Compute Cluster Base Operator to set up and status operations of compute cluster. + + :param compute_profile_name: Name of the Compute Profile to manage. + :param compute_group_name: Name of compute group to which compute profile belongs. + :param teradata_conn_id: The :ref:`Teradata connection id ` + reference to a specific Teradata database. + :param timeout: Time elapsed before the task times out and fails. 
+ """ + + template_fields: Sequence[str] = ( + "compute_profile_name", + "compute_group_name", + "teradata_conn_id", + "timeout", + ) + + ui_color = "#e07c24" + + def __init__( + self, + compute_profile_name: str, + compute_group_name: str | None = None, + teradata_conn_id: str = TeradataHook.default_conn_name, + timeout: int = Constants.CC_OPR_TIME_OUT, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.compute_profile_name = compute_profile_name + self.compute_group_name = compute_group_name + self.teradata_conn_id = teradata_conn_id + self.timeout = timeout + + @cached_property + def hook(self) -> TeradataHook: + return TeradataHook(teradata_conn_id=self.teradata_conn_id) + + @abstractmethod + def execute(self, context: Context): + pass + + def execute_complete(self, context: Context, event: dict[str, Any]) -> None: + """ + Execute when the trigger fires - returns immediately. + + Relies on trigger to throw an exception, otherwise it assumes execution was successful. + """ + self._compute_cluster_execute_complete(event) + + def _compute_cluster_execute(self): + # Verifies the provided compute profile name. + if ( + self.compute_profile_name is None + or self.compute_profile_name == "None" + or self.compute_profile_name == "" + ): + self.log.info("Invalid compute cluster profile name") + raise AirflowException(Constants.CC_OPR_EMPTY_PROFILE_ERROR_MSG) + # Verifies if the provided Teradata instance belongs to Vantage Cloud Lake. + lake_support_find_sql = "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'" + lake_support_result = self.hook.run(lake_support_find_sql, handler=_single_result_row_handler) + if lake_support_result is None: + raise AirflowException(Constants.CC_GRP_LAKE_SUPPORT_ONLY_MSG) + # Getting teradata db version. 
Considering teradata instance is Lake when db version is 20 or above + db_version_get_sql = "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'" + try: + db_version_result = self.hook.run(db_version_get_sql, handler=_single_result_row_handler) + if db_version_result is not None: + db_version_result = str(db_version_result) + db_version = db_version_result.split(".")[0] + if db_version is not None and int(db_version) < 20: + raise AirflowException(Constants.CC_GRP_LAKE_SUPPORT_ONLY_MSG) + else: + raise AirflowException("Error occurred while getting teradata database version") + except Exception as ex: + self.log.error("Error occurred while getting teradata database version: %s ", str(ex)) + raise AirflowException("Error occurred while getting teradata database version") + + def _compute_cluster_execute_complete(self, event: dict[str, Any]) -> None: + if event["status"] == "success": + return event["message"] + elif event["status"] == "error": + raise AirflowException(event["message"]) + + def _handle_cc_status(self, operation_type, sql): + create_sql_result = self._hook_run(sql, handler=_single_result_row_handler) + self.log.info( + "%s query ran successfully. Differing to trigger to check status in db. 
Result from sql: %s", + operation_type, + create_sql_result, + ) + self.defer( + timeout=timedelta(minutes=self.timeout), + trigger=TeradataComputeClusterSyncTrigger( + teradata_conn_id=cast(str, self.teradata_conn_id), + compute_profile_name=self.compute_profile_name, + compute_group_name=self.compute_group_name, + operation_type=operation_type, + poll_interval=Constants.CC_POLL_INTERVAL, + ), + method_name="execute_complete", + ) + + return create_sql_result + + def _hook_run(self, query, handler=None): + try: + if handler is not None: + return self.hook.run(query, handler=handler) + else: + return self.hook.run(query) + except Exception as ex: + self.log.error(str(ex)) + raise + + def _get_initially_suspended(self, create_cp_query): + initially_suspended = "FALSE" + pattern = r"INITIALLY_SUSPENDED\s*\(\s*'(TRUE|FALSE)'\s*\)" + # Search for the pattern in the input string + match = re.search(pattern, create_cp_query, re.IGNORECASE) + if match: + # Get the value of INITIALLY_SUSPENDED + initially_suspended = match.group(1).strip().upper() + return initially_suspended + + +class TeradataComputeClusterProvisionOperator(_TeradataComputeClusterOperator): + """ + + Creates the new Computer Cluster with specified Compute Group Name and Compute Profile Name. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TeradataComputeClusterProvisionOperator` + + :param compute_profile_name: Name of the Compute Profile to manage. + :param compute_group_name: Name of compute group to which compute profile belongs. + :param query_strategy: Query strategy to use. Refers to the approach or method used by the + Teradata Optimizer to execute SQL queries efficiently within a Teradata computer cluster. + Valid query_strategy value is either 'STANDARD' or 'ANALYTIC'. Default at database level is STANDARD. + :param compute_map: ComputeMapName of the compute map. 
The compute_map in a compute cluster profile refers + to the mapping of compute resources to a specific node or set of nodes within the cluster. + :param compute_attribute: Optional attributes of compute profile. Example compute attribute + MIN_COMPUTE_COUNT(1) MAX_COMPUTE_COUNT(5) INITIALLY_SUSPENDED('FALSE') + :param teradata_conn_id: The :ref:`Teradata connection id ` + reference to a specific Teradata database. + :param timeout: Time elapsed before the task times out and fails. + """ + + template_fields: Sequence[str] = ( + "compute_profile_name", + "compute_group_name", + "query_strategy", + "compute_map", + "compute_attribute", + "teradata_conn_id", + "timeout", + ) + + ui_color = "#e07c24" + + def __init__( + self, + query_strategy: str | None = None, + compute_map: str | None = None, + compute_attribute: str | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.query_strategy = query_strategy + self.compute_map = compute_map + self.compute_attribute = compute_attribute + + def _build_ccp_setup_query(self): + create_cp_query = "CREATE COMPUTE PROFILE " + self.compute_profile_name + if self.compute_group_name: + create_cp_query = create_cp_query + " IN " + self.compute_group_name + if self.compute_map is not None: + create_cp_query = create_cp_query + ", INSTANCE = " + self.compute_map + if self.query_strategy is not None: + create_cp_query = create_cp_query + ", INSTANCE TYPE = " + self.query_strategy + if self.compute_attribute is not None: + create_cp_query = create_cp_query + " USING " + self.compute_attribute + return create_cp_query + + def execute(self, context: Context): + """ + Initiate the execution of CREATE COMPUTE SQL statement. + + Initiate the execution of the SQL statement for provisioning the compute cluster within Teradata Vantage + Lake, effectively creates the compute cluster. + Airflow runs this method on the worker and defers using the trigger. 
+ """ + super().execute(context) + return self._compute_cluster_execute() + + def _compute_cluster_execute(self): + super()._compute_cluster_execute() + if self.compute_group_name: + cg_status_query = ( + "SELECT count(1) FROM DBC.ComputeGroups WHERE UPPER(ComputeGroupName) = UPPER('" + + self.compute_group_name + + "')" + ) + cg_status_result = self._hook_run(cg_status_query, _single_result_row_handler) + if cg_status_result is not None: + cg_status_result = str(cg_status_result) + else: + cg_status_result = 0 + if int(cg_status_result) == 0: + create_cg_query = "CREATE COMPUTE GROUP " + self.compute_group_name + if self.query_strategy is not None: + create_cg_query = ( + create_cg_query + " USING QUERY_STRATEGY ('" + self.query_strategy + "')" + ) + self._hook_run(create_cg_query, _single_result_row_handler) + cp_status_query = ( + "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('" + + self.compute_profile_name + + "')" + ) + if self.compute_group_name: + cp_status_query += " AND UPPER(ComputeGroupName) = UPPER('" + self.compute_group_name + "')" + cp_status_result = self._hook_run(cp_status_query, handler=_single_result_row_handler) + if cp_status_result is not None: + cp_status_result = str(cp_status_result) + msg = f"Compute Profile {self.compute_profile_name} is already exists under Compute Group {self.compute_group_name}. Status is {cp_status_result}" + self.log.info(msg) + return cp_status_result + else: + create_cp_query = self._build_ccp_setup_query() + operation = Constants.CC_CREATE_OPR + initially_suspended = self._get_initially_suspended(create_cp_query) + if initially_suspended == "TRUE": + operation = Constants.CC_CREATE_SUSPEND_OPR + return self._handle_cc_status(operation, create_cp_query) + + +class TeradataComputeClusterDecommissionOperator(_TeradataComputeClusterOperator): + """ + Drops the compute cluster with specified Compute Group Name and Compute Profile Name. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TeradataComputeClusterDecommissionOperator` + + :param compute_profile_name: Name of the Compute Profile to manage. + :param compute_group_name: Name of compute group to which compute profile belongs. + :param delete_compute_group: Indicates whether the compute group should be deleted. + When set to True, it signals the system to remove the specified compute group. + Conversely, when set to False, no action is taken on the compute group. + :param teradata_conn_id: The :ref:`Teradata connection id ` + reference to a specific Teradata database. + :param timeout: Time elapsed before the task times out and fails. + """ + + template_fields: Sequence[str] = ( + "compute_profile_name", + "compute_group_name", + "delete_compute_group", + "teradata_conn_id", + "timeout", + ) + + ui_color = "#e07c24" + + def __init__( + self, + delete_compute_group: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.delete_compute_group = delete_compute_group + + def execute(self, context: Context): + """ + Initiate the execution of DROP COMPUTE SQL statement. + + Initiate the execution of the SQL statement for decommissioning the compute cluster within Teradata Vantage + Lake, effectively drops the compute cluster. + Airflow runs this method on the worker and defers using the trigger. 
+ """ + super().execute(context) + return self._compute_cluster_execute() + + def _compute_cluster_execute(self): + super()._compute_cluster_execute() + cp_drop_query = "DROP COMPUTE PROFILE " + self.compute_profile_name + if self.compute_group_name: + cp_drop_query = cp_drop_query + " IN COMPUTE GROUP " + self.compute_group_name + self._hook_run(cp_drop_query, handler=_single_result_row_handler) + self.log.info( + "Compute Profile %s IN Compute Group %s is successfully dropped", + self.compute_profile_name, + self.compute_group_name, + ) + if self.delete_compute_group: + cg_drop_query = "DROP COMPUTE GROUP " + self.compute_group_name + self._hook_run(cg_drop_query, handler=_single_result_row_handler) + self.log.info("Compute Group %s is successfully dropped", self.compute_group_name) + + +class TeradataComputeClusterResumeOperator(_TeradataComputeClusterOperator): + """ + Teradata Compute Cluster Operator to Resume the specified Teradata Vantage Cloud Lake Compute Cluster. + + Resumes the Teradata Vantage Lake Computer Cluster by employing the RESUME SQL statement within the + Teradata Vantage Lake Compute Cluster SQL Interface. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TeradataComputeClusterResumeOperator` + + :param compute_profile_name: Name of the Compute Profile to manage. + :param compute_group_name: Name of compute group to which compute profile belongs. + :param teradata_conn_id: The :ref:`Teradata connection id ` + reference to a specific Teradata database. + :param timeout: Time elapsed before the task times out and fails. Time is in minutes. + """ + + template_fields: Sequence[str] = ( + "compute_profile_name", + "compute_group_name", + "teradata_conn_id", + "timeout", + ) + + ui_color = "#e07c24" + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def execute(self, context: Context): + """ + Initiate the execution of RESUME COMPUTE SQL statement. 
+ + Initiate the execution of the SQL statement for resuming the compute cluster within Teradata Vantage + Lake, effectively resumes the compute cluster. + Airflow runs this method on the worker and defers using the trigger. + """ + super().execute(context) + return self._compute_cluster_execute() + + def _compute_cluster_execute(self): + super()._compute_cluster_execute() + cc_status_query = ( + "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('" + + self.compute_profile_name + + "')" + ) + if self.compute_group_name: + cc_status_query += " AND UPPER(ComputeGroupName) = UPPER('" + self.compute_group_name + "')" + cc_status_result = self._hook_run(cc_status_query, handler=_single_result_row_handler) + if cc_status_result is not None: + cp_status_result = str(cc_status_result) + # Generates an error message if the compute cluster does not exist for the specified + # compute profile and compute group. + else: + self.log.info(Constants.CC_GRP_PRP_NON_EXISTS_MSG) + raise AirflowException(Constants.CC_GRP_PRP_NON_EXISTS_MSG) + if cp_status_result != Constants.CC_RESUME_DB_STATUS: + cp_resume_query = f"RESUME COMPUTE FOR COMPUTE PROFILE {self.compute_profile_name}" + if self.compute_group_name: + cp_resume_query = f"{cp_resume_query} IN COMPUTE GROUP {self.compute_group_name}" + return self._handle_cc_status(Constants.CC_RESUME_OPR, cp_resume_query) + else: + self.log.info( + "Compute Cluster %s already %s", self.compute_profile_name, Constants.CC_RESUME_DB_STATUS + ) + + +class TeradataComputeClusterSuspendOperator(_TeradataComputeClusterOperator): + """ + Teradata Compute Cluster Operator to suspend the specified Teradata Vantage Cloud Lake Compute Cluster. + + Suspends the Teradata Vantage Lake Computer Cluster by employing the SUSPEND SQL statement within the + Teradata Vantage Lake Compute Cluster SQL Interface. + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TeradataComputeClusterSuspendOperator` + + :param compute_profile_name: Name of the Compute Profile to manage. + :param compute_group_name: Name of compute group to which compute profile belongs. + :param teradata_conn_id: The :ref:`Teradata connection id ` + reference to a specific Teradata database. + :param timeout: Time elapsed before the task times out and fails. + """ + + template_fields: Sequence[str] = ( + "compute_profile_name", + "compute_group_name", + "teradata_conn_id", + "timeout", + ) + + ui_color = "#e07c24" + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + def execute(self, context: Context): + """ + Initiate the execution of SUSPEND COMPUTE SQL statement. + + Initiate the execution of the SQL statement for suspending the compute cluster within Teradata Vantage + Lake, effectively suspends the compute cluster. + Airflow runs this method on the worker and defers using the trigger. + """ + super().execute(context) + return self._compute_cluster_execute() + + def _compute_cluster_execute(self): + super()._compute_cluster_execute() + sql = ( + "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('" + + self.compute_profile_name + + "')" + ) + if self.compute_group_name: + sql += " AND UPPER(ComputeGroupName) = UPPER('" + self.compute_group_name + "')" + result = self._hook_run(sql, handler=_single_result_row_handler) + if result is not None: + result = str(result) + # Generates an error message if the compute cluster does not exist for the specified + # compute profile and compute group. 
+ else: + self.log.info(Constants.CC_GRP_PRP_NON_EXISTS_MSG) + raise AirflowException(Constants.CC_GRP_PRP_NON_EXISTS_MSG) + if result != Constants.CC_SUSPEND_DB_STATUS: + sql = f"SUSPEND COMPUTE FOR COMPUTE PROFILE {self.compute_profile_name}" + if self.compute_group_name: + sql = f"{sql} IN COMPUTE GROUP {self.compute_group_name}" + return self._handle_cc_status(Constants.CC_SUSPEND_OPR, sql) + else: + self.log.info( + "Compute Cluster %s already %s", self.compute_profile_name, Constants.CC_SUSPEND_DB_STATUS + ) diff --git a/airflow/providers/teradata/provider.yaml b/airflow/providers/teradata/provider.yaml index 9855cec14d377..a5a93ad7a85ac 100644 --- a/airflow/providers/teradata/provider.yaml +++ b/airflow/providers/teradata/provider.yaml @@ -50,6 +50,7 @@ integrations: external-doc-url: https://www.teradata.com/ how-to-guide: - /docs/apache-airflow-providers-teradata/operators/teradata.rst + - /docs/apache-airflow-providers-teradata/operators/compute_cluster.rst logo: /integration-logos/teradata/Teradata.png tags: [software] @@ -57,6 +58,7 @@ operators: - integration-name: Teradata python-modules: - airflow.providers.teradata.operators.teradata + - airflow.providers.teradata.operators.teradata_compute_cluster hooks: - integration-name: Teradata @@ -80,3 +82,8 @@ transfers: connection-types: - hook-class-name: airflow.providers.teradata.hooks.teradata.TeradataHook connection-type: teradata + +triggers: + - integration-name: Teradata + python-modules: + - airflow.providers.teradata.triggers.teradata_compute_cluster diff --git a/airflow/providers/teradata/triggers/__init__.py b/airflow/providers/teradata/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/teradata/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/teradata/triggers/teradata_compute_cluster.py b/airflow/providers/teradata/triggers/teradata_compute_cluster.py new file mode 100644 index 0000000000000..5b971535d26d2 --- /dev/null +++ b/airflow/providers/teradata/triggers/teradata_compute_cluster.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator + +from airflow.exceptions import AirflowException +from airflow.providers.common.sql.hooks.sql import fetch_one_handler +from airflow.providers.teradata.hooks.teradata import TeradataHook +from airflow.providers.teradata.utils.constants import Constants +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class TeradataComputeClusterSyncTrigger(BaseTrigger): + """ + Fetch the status of the suspend or resume operation for the specified compute cluster. + + :param teradata_conn_id: The :ref:`Teradata connection id ` + reference to a specific Teradata database. + :param compute_profile_name: Name of the Compute Profile to manage. + :param compute_group_name: Name of compute group to which compute profile belongs. + :param operation_type: Compute cluster operation - CREATE/CREATE_SUSPEND/SUSPEND/RESUME + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + teradata_conn_id: str, + compute_profile_name: str, + compute_group_name: str | None = None, + operation_type: str | None = None, + poll_interval: float | None = None, + ): + super().__init__() + self.teradata_conn_id = teradata_conn_id + self.compute_profile_name = compute_profile_name + self.compute_group_name = compute_group_name + self.operation_type = operation_type + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize TeradataComputeClusterSyncTrigger arguments and classpath.""" + return ( + "airflow.providers.teradata.triggers.teradata_compute_cluster.TeradataComputeClusterSyncTrigger", + { + "teradata_conn_id": self.teradata_conn_id, + "compute_profile_name": self.compute_profile_name, + "compute_group_name": self.compute_group_name, + "operation_type": self.operation_type, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Wait for Compute Cluster operation to 
complete.""" + try: + while True: + status = await self.get_status() + if status is None or len(status) == 0: + self.log.info(Constants.CC_GRP_PRP_NON_EXISTS_MSG) + raise AirflowException(Constants.CC_GRP_PRP_NON_EXISTS_MSG) + if ( + self.operation_type == Constants.CC_SUSPEND_OPR + or self.operation_type == Constants.CC_CREATE_SUSPEND_OPR + ): + if status == Constants.CC_SUSPEND_DB_STATUS: + break + elif ( + self.operation_type == Constants.CC_RESUME_OPR + or self.operation_type == Constants.CC_CREATE_OPR + ): + if status == Constants.CC_RESUME_DB_STATUS: + break + if self.poll_interval is not None: + self.poll_interval = float(self.poll_interval) + else: + self.poll_interval = float(Constants.CC_POLL_INTERVAL) + await asyncio.sleep(self.poll_interval) + if ( + self.operation_type == Constants.CC_SUSPEND_OPR + or self.operation_type == Constants.CC_CREATE_SUSPEND_OPR + ): + if status == Constants.CC_SUSPEND_DB_STATUS: + yield TriggerEvent( + { + "status": "success", + "message": Constants.CC_OPR_SUCCESS_STATUS_MSG + % (self.compute_profile_name, self.operation_type), + } + ) + else: + yield TriggerEvent( + { + "status": "error", + "message": Constants.CC_OPR_FAILURE_STATUS_MSG + % (self.compute_profile_name, self.operation_type), + } + ) + elif ( + self.operation_type == Constants.CC_RESUME_OPR + or self.operation_type == Constants.CC_CREATE_OPR + ): + if status == Constants.CC_RESUME_DB_STATUS: + yield TriggerEvent( + { + "status": "success", + "message": Constants.CC_OPR_SUCCESS_STATUS_MSG + % (self.compute_profile_name, self.operation_type), + } + ) + else: + yield TriggerEvent( + { + "status": "error", + "message": Constants.CC_OPR_FAILURE_STATUS_MSG + % (self.compute_profile_name, self.operation_type), + } + ) + else: + yield TriggerEvent({"status": "error", "message": "Invalid operation"}) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + except asyncio.CancelledError: + self.log.error(Constants.CC_OPR_TIMEOUT_ERROR, 
self.operation_type) + + async def get_status(self) -> str: + """Return compute cluster SUSPEND/RESUME operation status.""" + sql = ( + "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('" + + self.compute_profile_name + + "')" + ) + if self.compute_group_name: + sql += " AND UPPER(ComputeGroupName) = UPPER('" + self.compute_group_name + "')" + hook = TeradataHook(teradata_conn_id=self.teradata_conn_id) + result_set = hook.run(sql, handler=fetch_one_handler) + status = "" + if isinstance(result_set, list) and isinstance(result_set[0], str): + status = str(result_set[0]) + return status diff --git a/airflow/providers/teradata/utils/__init__.py b/airflow/providers/teradata/utils/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/teradata/utils/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/airflow/providers/teradata/utils/constants.py b/airflow/providers/teradata/utils/constants.py new file mode 100644 index 0000000000000..ee356ceb402e5 --- /dev/null +++ b/airflow/providers/teradata/utils/constants.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +class Constants: + """Define constants for Teradata Provider.""" + + CC_CREATE_OPR = "CREATE" + CC_CREATE_SUSPEND_OPR = "CREATE_SUSPEND" + CC_DROP_OPR = "DROP" + CC_SUSPEND_OPR = "SUSPEND" + CC_RESUME_OPR = "RESUME" + CC_INITIALIZE_DB_STATUS = "Initializing" + CC_SUSPEND_DB_STATUS = "Suspended" + CC_RESUME_DB_STATUS = "Running" + CC_OPR_SUCCESS_STATUS_MSG = "Compute Cluster %s %s operation completed successfully." + CC_OPR_FAILURE_STATUS_MSG = "Compute Cluster %s %s operation has failed." + CC_OPR_INITIALIZING_STATUS_MSG = "The environment is currently initializing. Please wait." + CC_OPR_EMPTY_PROFILE_ERROR_MSG = "Please provide a valid name for the compute cluster profile." + CC_GRP_PRP_NON_EXISTS_MSG = "The specified Compute cluster is not present or The user doesn't have permission to access compute cluster." + CC_GRP_PRP_UN_AUTHORIZED_MSG = "The %s operation is not authorized for the user." 
+ CC_GRP_LAKE_SUPPORT_ONLY_MSG = "Compute Groups is supported only on Vantage Cloud Lake." + CC_OPR_TIMEOUT_ERROR = ( + "There is an issue with the %s operation. Kindly consult the administrator for assistance." + ) + CC_GRP_PRP_EXISTS_MSG = "The specified Compute cluster is already exists." + CC_OPR_EMPTY_COPY_PROFILE_ERROR_MSG = ( + "Please provide a valid name for the source and target compute profile." + ) + CC_OPR_TIME_OUT = 1200 + CC_POLL_INTERVAL = 60 diff --git a/docs/apache-airflow-providers-teradata/operators/compute_cluster.rst b/docs/apache-airflow-providers-teradata/operators/compute_cluster.rst new file mode 100644 index 0000000000000..ceaf27ee74676 --- /dev/null +++ b/docs/apache-airflow-providers-teradata/operators/compute_cluster.rst @@ -0,0 +1,107 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +.. _howto/operator:TeradataComputeClusterProvisionOperator: + + +======================================= +TeradataComputeClusterProvisionOperator +======================================= + +The purpose of ``TeradataComputeClusterProvisionOperator`` is to provision the new Teradata Vantage Cloud Lake +Compute Cluster with specified Compute Group Name and Compute Profile Name. 
+Use the :class:`TeradataComputeClusterProvisionOperator <airflow.providers.teradata.operators.teradata_compute_cluster.TeradataComputeClusterProvisionOperator>` +to provision the new Compute Cluster in Teradata Vantage Cloud Lake. + + + +An example usage of the TeradataComputeClusterProvisionOperator to provision the new Compute Cluster in +Teradata Vantage Cloud Lake is as follows: + +.. exampleinclude:: /../../tests/system/providers/teradata/example_teradata_compute_cluster.py + :language: python + :start-after: [START teradata_vantage_lake_compute_cluster_provision_howto_guide] + :end-before: [END teradata_vantage_lake_compute_cluster_provision_howto_guide] + + +.. _howto/operator:TeradataComputeClusterDecommissionOperator: + + +========================================== +TeradataComputeClusterDecommissionOperator +========================================== + +The purpose of ``TeradataComputeClusterDecommissionOperator`` is to decommission the specified Teradata Vantage Cloud Lake +Compute Cluster. +Use the :class:`TeradataComputeClusterDecommissionOperator <airflow.providers.teradata.operators.teradata_compute_cluster.TeradataComputeClusterDecommissionOperator>` +to decommission the specified Teradata Vantage Cloud Lake Compute Cluster. + + + +An example usage of the TeradataComputeClusterDecommissionOperator to decommission the specified Teradata Vantage Cloud +Lake Compute Cluster is as follows: + +.. exampleinclude:: /../../tests/system/providers/teradata/example_teradata_compute_cluster.py + :language: python + :start-after: [START teradata_vantage_lake_compute_cluster_decommission_howto_guide] + :end-before: [END teradata_vantage_lake_compute_cluster_decommission_howto_guide] + + +.. _howto/operator:TeradataComputeClusterResumeOperator: + + +===================================== +TeradataComputeClusterResumeOperator +===================================== + +The purpose of ``TeradataComputeClusterResumeOperator`` is to start the Teradata Vantage Cloud Lake +Compute Cluster of specified Compute Group Name and Compute Profile Name. +Use the :class:`TeradataComputeClusterResumeOperator <airflow.providers.teradata.operators.teradata_compute_cluster.TeradataComputeClusterResumeOperator>` +to start the specified Compute Cluster in Teradata Vantage Cloud Lake. 
+ + + +An example usage of the TeradataComputeClusterResumeOperator to start the specified Compute Cluster in +Teradata Vantage Cloud Lake is as follows: + +.. exampleinclude:: /../../tests/system/providers/teradata/example_teradata_compute_cluster.py + :language: python + :start-after: [START teradata_vantage_lake_compute_cluster_resume_howto_guide] + :end-before: [END teradata_vantage_lake_compute_cluster_resume_howto_guide] + +.. _howto/operator:TeradataComputeClusterSuspendOperator: + + +===================================== +TeradataComputeClusterSuspendOperator +===================================== + +The purpose of ``TeradataComputeClusterSuspendOperator`` is to suspend the Teradata Vantage Cloud Lake +Compute Cluster of specified Compute Group Name and Compute Profile Name. +Use the :class:`TeradataComputeClusterSuspendOperator <airflow.providers.teradata.operators.teradata_compute_cluster.TeradataComputeClusterSuspendOperator>` +to suspend the specified Compute Cluster in Teradata Vantage Cloud Lake. + + + +An example usage of the TeradataComputeClusterSuspendOperator to suspend the specified Compute Cluster in +Teradata Vantage Cloud Lake is as follows: + +.. exampleinclude:: /../../tests/system/providers/teradata/example_teradata_compute_cluster.py + :language: python + :start-after: [START teradata_vantage_lake_compute_cluster_suspend_howto_guide] + :end-before: [END teradata_vantage_lake_compute_cluster_suspend_howto_guide] diff --git a/tests/providers/teradata/operators/test_teradata_compute_cluster.py b/tests/providers/teradata/operators/test_teradata_compute_cluster.py new file mode 100644 index 0000000000000..d36f59848b8ad --- /dev/null +++ b/tests/providers/teradata/operators/test_teradata_compute_cluster.py @@ -0,0 +1,713 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import call, patch + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.teradata.operators.teradata_compute_cluster import ( + TeradataComputeClusterDecommissionOperator, + TeradataComputeClusterProvisionOperator, + TeradataComputeClusterResumeOperator, + TeradataComputeClusterSuspendOperator, + _single_result_row_handler, +) +from airflow.providers.teradata.triggers.teradata_compute_cluster import TeradataComputeClusterSyncTrigger +from airflow.providers.teradata.utils.constants import Constants + + +@pytest.fixture +def compute_profile_name(): + return "test_profile" + + +@pytest.fixture +def compute_group_name(): + return "test_group" + + +@pytest.fixture +def query_strategy(): + return "test_query_strategy" + + +@pytest.fixture +def compute_map(): + return "test_compute_map" + + +@pytest.fixture +def compute_attribute(): + return "test_compute_attribute" + + +@pytest.fixture +def compute_cluster_provision_instance(compute_profile_name): + return TeradataComputeClusterProvisionOperator( + task_id="test", compute_profile_name=compute_profile_name, teradata_conn_id="test_conn" + ) + + +@pytest.fixture +def compute_cluster_decommission_instance(compute_profile_name): + return TeradataComputeClusterDecommissionOperator( + task_id="test", compute_profile_name=compute_profile_name, 
teradata_conn_id="test_conn" + ) + + +@pytest.fixture +def compute_cluster_resume_instance(compute_profile_name): + return TeradataComputeClusterResumeOperator( + task_id="test", compute_profile_name=compute_profile_name, teradata_conn_id="test_conn" + ) + + +@pytest.fixture +def compute_cluster_suspend_instance(compute_profile_name): + return TeradataComputeClusterSuspendOperator( + task_id="test", compute_profile_name=compute_profile_name, teradata_conn_id="test_conn" + ) + + +class TestTeradataComputeClusterOperator: + def test_compute_cluster_execute_invalid_profile(self, compute_cluster_provision_instance): + compute_cluster_provision_instance.compute_profile_name = None + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_empty_profile(self, compute_cluster_provision_instance): + compute_cluster_provision_instance.compute_profile_name = "" + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_none_profile(self, compute_cluster_provision_instance): + compute_cluster_provision_instance.compute_profile_name = "None" + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_not_lake(self, compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = [None] + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_not_lake_version_check(self, compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "19"] + with pytest.raises(AirflowException): + 
compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_not_lake_version_none(self, compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", None] + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_not_lake_version_invalid(self, compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "invalid"] + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute() + + def test_compute_cluster_execute_complete_success(self, compute_cluster_provision_instance): + event = {"status": "success", "message": "Success message"} + # Call the method under test + result = compute_cluster_provision_instance._compute_cluster_execute_complete(event) + assert result == "Success message" + + def test_compute_cluster_execute_complete_error(self, compute_cluster_provision_instance): + event = {"status": "error", "message": "Error message"} + with pytest.raises(AirflowException): + compute_cluster_provision_instance._compute_cluster_execute_complete(event) + + def test_cc_execute_provision_new_cp(self, compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", None, "Success"] + compute_profile_name = compute_cluster_provision_instance.compute_profile_name + with patch.object(compute_cluster_provision_instance, "defer") as mock_defer: + # Assert that defer method is called with the correct parameters + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_provision_instance.teradata_conn_id, + 
compute_profile_name=compute_profile_name, + operation_type=Constants.CC_CREATE_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_provision_instance._compute_cluster_execute() + assert result == "Success" + mock_defer.assert_called_once() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')", + handler=_single_result_row_handler, + ), + call( + f"CREATE COMPUTE PROFILE {compute_profile_name}", + handler=_single_result_row_handler, + ), + ] + ) + + def test_cc_execute_provision_exists_cp(self, compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "RUNNING", "Success"] + compute_profile_name = compute_cluster_provision_instance.compute_profile_name + with patch.object(compute_cluster_provision_instance, "defer") as mock_defer: + # Assert that defer method is called with the correct parameters + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_provision_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + operation_type=Constants.CC_CREATE_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_provision_instance._compute_cluster_execute() + assert result == "RUNNING" + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM 
DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')", + handler=_single_result_row_handler, + ), + ] + ) + + def test_cc_execute_provision_new_cp_exists_cg( + self, compute_group_name, compute_cluster_provision_instance + ): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "1", None, "Success"] + compute_cluster_provision_instance.compute_group_name = compute_group_name + compute_profile_name = compute_cluster_provision_instance.compute_profile_name + with patch.object(compute_cluster_provision_instance, "defer") as mock_defer: + # Assert that defer method is called with the correct parameters + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_provision_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + compute_group_name=compute_group_name, + operation_type=Constants.CC_CREATE_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_provision_instance._compute_cluster_execute() + assert result == "Success" + mock_defer.assert_called_once() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SELECT count(1) FROM DBC.ComputeGroups WHERE UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}') AND UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + 
handler=_single_result_row_handler, + ), + call( + f"CREATE COMPUTE PROFILE {compute_profile_name} IN {compute_group_name}", + handler=_single_result_row_handler, + ), + ] + ) + + def test_cc_execute_provision_exists_cp_exists_cg( + self, compute_group_name, compute_cluster_provision_instance + ): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "1", "RUNNING", "Success"] + compute_profile_name = compute_cluster_provision_instance.compute_profile_name + compute_cluster_provision_instance.compute_group_name = compute_group_name + with patch.object(compute_cluster_provision_instance, "defer") as mock_defer: + # Assert that defer method is called with the correct parameters + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_provision_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + compute_group_name=compute_group_name, + operation_type=Constants.CC_CREATE_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_provision_instance._compute_cluster_execute() + assert result == "RUNNING" + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SELECT count(1) FROM DBC.ComputeGroups WHERE UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}') AND UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + ] + ) + + def test_cc_execute_provision_new_cp_new_cg(self, compute_group_name, 
compute_cluster_provision_instance): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "0", "Success", None, "Success"] + compute_cluster_provision_instance.compute_group_name = compute_group_name + compute_profile_name = compute_cluster_provision_instance.compute_profile_name + with patch.object(compute_cluster_provision_instance, "defer") as mock_defer: + # Assert that defer method is called with the correct parameters + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_provision_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + compute_group_name=compute_group_name, + operation_type=Constants.CC_CREATE_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_provision_instance._compute_cluster_execute() + assert result == "Success" + mock_defer.assert_called_once() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SELECT count(1) FROM DBC.ComputeGroups WHERE UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"CREATE COMPUTE GROUP {compute_group_name}", handler=_single_result_row_handler + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}') AND UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"CREATE COMPUTE PROFILE {compute_profile_name} IN {compute_group_name}", + handler=_single_result_row_handler, + ), + ] + ) + + def test_cc_execute_provision_new_cp_new_cg_with_options( + self, + 
compute_group_name, + query_strategy, + compute_map, + compute_attribute, + compute_cluster_provision_instance, + ): + with patch.object(compute_cluster_provision_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "0", "Success", None, "Success"] + compute_cluster_provision_instance.compute_group_name = compute_group_name + compute_profile_name = compute_cluster_provision_instance.compute_profile_name + compute_cluster_provision_instance.query_strategy = query_strategy + compute_cluster_provision_instance.compute_map = compute_map + compute_cluster_provision_instance.compute_attribute = compute_attribute + + with patch.object(compute_cluster_provision_instance, "defer") as mock_defer: + # Assert that defer method is called with the correct parameters + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_provision_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + compute_group_name=compute_group_name, + operation_type=Constants.CC_CREATE_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_provision_instance._compute_cluster_execute() + assert result == "Success" + mock_defer.assert_called_once() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SELECT count(1) FROM DBC.ComputeGroups WHERE UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"CREATE COMPUTE GROUP {compute_group_name} USING QUERY_STRATEGY ('{query_strategy}')", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = 
UPPER('{compute_profile_name}') AND UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"CREATE COMPUTE PROFILE {compute_profile_name} IN {compute_group_name}, " + f"INSTANCE = {compute_map}, INSTANCE TYPE = {query_strategy} USING {compute_attribute}", + handler=_single_result_row_handler, + ), + ] + ) + + def test_compute_cluster_execute_drop_cp(self, compute_cluster_decommission_instance): + with patch.object(compute_cluster_decommission_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", None] + compute_profile_name = compute_cluster_decommission_instance.compute_profile_name + compute_cluster_decommission_instance._compute_cluster_execute() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call(f"DROP COMPUTE PROFILE {compute_profile_name}", handler=_single_result_row_handler), + ] + ) + + def test_compute_cluster_execute_drop_cp_cg( + self, compute_cluster_decommission_instance, compute_group_name + ): + with patch.object(compute_cluster_decommission_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", None, None] + compute_profile_name = compute_cluster_decommission_instance.compute_profile_name + compute_cluster_decommission_instance.compute_group_name = compute_group_name + compute_cluster_decommission_instance.delete_compute_group = True + compute_cluster_decommission_instance._compute_cluster_execute() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + 
handler=_single_result_row_handler, + ), + call( + f"DROP COMPUTE PROFILE {compute_profile_name} IN COMPUTE GROUP {compute_group_name}", + handler=_single_result_row_handler, + ), + call(f"DROP COMPUTE GROUP {compute_group_name}", handler=_single_result_row_handler), + ] + ) + + def test_compute_cluster_execute_resume_success(self, compute_cluster_resume_instance): + with patch.object(compute_cluster_resume_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "Suspended", "Success"] + compute_profile_name = compute_cluster_resume_instance.compute_profile_name + + with patch.object(compute_cluster_resume_instance, "defer") as mock_defer: + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_resume_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + operation_type=Constants.CC_RESUME_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_resume_instance._compute_cluster_execute() + assert result == "Success" + mock_defer.assert_called_once() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')", + handler=_single_result_row_handler, + ), + call( + f"RESUME COMPUTE FOR COMPUTE PROFILE {compute_profile_name}", + handler=_single_result_row_handler, + ), + ] + ) + + def test_compute_cluster_execute_resume_cg_success( + self, compute_group_name, compute_cluster_resume_instance + ): + with patch.object(compute_cluster_resume_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", 
"Suspended", "Success"] + compute_profile_name = compute_cluster_resume_instance.compute_profile_name + compute_cluster_resume_instance.compute_group_name = compute_group_name + with patch.object(compute_cluster_resume_instance, "defer") as mock_defer: + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_resume_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + operation_type=Constants.CC_RESUME_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = compute_cluster_resume_instance._compute_cluster_execute() + assert result == "Success" + mock_defer.assert_called_once() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')" + f" AND UPPER(ComputeGroupName) = UPPER('{compute_group_name}')", + handler=_single_result_row_handler, + ), + call( + f"RESUME COMPUTE FOR COMPUTE PROFILE {compute_profile_name} IN COMPUTE GROUP {compute_group_name}", + handler=_single_result_row_handler, + ), + ] + ) + + def test_compute_cluster_execute_resume_cc_not_exists( + self, compute_group_name, compute_cluster_resume_instance + ): + with patch.object(compute_cluster_resume_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", None] + compute_profile_name = compute_cluster_resume_instance.compute_profile_name + with pytest.raises(AirflowException): + compute_cluster_resume_instance._compute_cluster_execute() + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + 
), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')", + handler=_single_result_row_handler, + ), + ] + ) + + def test_compute_cluster_execute_resume_same_state(self, compute_cluster_resume_instance): + with patch.object(compute_cluster_resume_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "Running"] + compute_profile_name = compute_cluster_resume_instance.compute_profile_name + result = compute_cluster_resume_instance._compute_cluster_execute() + assert result is None + mock_hook.run.assert_has_calls( + [ + call( + "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'", + handler=_single_result_row_handler, + ), + call( + "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'", + handler=_single_result_row_handler, + ), + call( + f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')", + handler=_single_result_row_handler, + ), + ] + ) + + def test_compute_cluster_execute_suspend_success(self, compute_cluster_suspend_instance): + with patch.object(compute_cluster_suspend_instance, "hook") as mock_hook: + # Set up mock return values + mock_hook.run.side_effect = ["1", "20.00", "Running", "Success"] + compute_profile_name = compute_cluster_suspend_instance.compute_profile_name + + with patch.object(compute_cluster_suspend_instance, "defer") as mock_defer: + expected_trigger = TeradataComputeClusterSyncTrigger( + teradata_conn_id=compute_cluster_suspend_instance.teradata_conn_id, + compute_profile_name=compute_profile_name, + operation_type=Constants.CC_RESUME_OPR, + poll_interval=Constants.CC_POLL_INTERVAL, + ) + mock_defer.return_value = expected_trigger + result = 
compute_cluster_suspend_instance._compute_cluster_execute()
+                assert result == "Success"
+                mock_defer.assert_called_once()
+            mock_hook.run.assert_has_calls(
+                [
+                    call(
+                        "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        f"SUSPEND COMPUTE FOR COMPUTE PROFILE {compute_profile_name}",
+                        handler=_single_result_row_handler,
+                    ),
+                ]
+            )
+
+    def test_compute_cluster_execute_suspend_cg_success(
+        self, compute_group_name, compute_cluster_suspend_instance
+    ):
+        with patch.object(compute_cluster_suspend_instance, "hook") as mock_hook:
+            # Replies in call order: storage count, DB version, profile state, SUSPEND result
+            mock_hook.run.side_effect = ["1", "20.00", "Running", "Success"]
+            compute_profile_name = compute_cluster_suspend_instance.compute_profile_name
+            compute_cluster_suspend_instance.compute_group_name = compute_group_name
+            with patch.object(compute_cluster_suspend_instance, "defer") as mock_defer:
+                expected_trigger = TeradataComputeClusterSyncTrigger(
+                    teradata_conn_id=compute_cluster_suspend_instance.teradata_conn_id,
+                    compute_profile_name=compute_profile_name,
+                    operation_type=Constants.CC_SUSPEND_OPR,  # this test exercises suspend, not resume
+                    poll_interval=Constants.CC_POLL_INTERVAL,
+                )
+                mock_defer.return_value = expected_trigger
+                result = compute_cluster_suspend_instance._compute_cluster_execute()
+                assert result == "Success"
+                mock_defer.assert_called_once()
+            mock_hook.run.assert_has_calls(
+                [
+                    call(
+                        "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')"
+                        f" AND UPPER(ComputeGroupName) = UPPER('{compute_group_name}')",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        f"SUSPEND COMPUTE FOR COMPUTE PROFILE {compute_profile_name} IN COMPUTE GROUP {compute_group_name}",
+                        handler=_single_result_row_handler,
+                    ),
+                ]
+            )
+
+    def test_compute_cluster_execute_suspend_cc_not_exists(
+        self, compute_group_name, compute_cluster_suspend_instance
+    ):
+        with patch.object(compute_cluster_suspend_instance, "hook") as mock_hook:
+            # Third reply (the profile-state lookup) is None; AirflowException is expected
+            mock_hook.run.side_effect = ["1", "20.00", None]
+            compute_profile_name = compute_cluster_suspend_instance.compute_profile_name
+            with pytest.raises(AirflowException):
+                compute_cluster_suspend_instance._compute_cluster_execute()
+            mock_hook.run.assert_has_calls(
+                [
+                    call(
+                        "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')",
+                        handler=_single_result_row_handler,
+                    ),
+                ]
+            )
+
+    def test_compute_cluster_execute_suspend_same_state(self, compute_cluster_suspend_instance):
+        with patch.object(compute_cluster_suspend_instance, "hook") as mock_hook:
+            # Profile state already reads "Suspended", so the call is expected to return None
+            mock_hook.run.side_effect = ["1", "20.00", "Suspended"]
+            compute_profile_name = compute_cluster_suspend_instance.compute_profile_name
+            result = compute_cluster_suspend_instance._compute_cluster_execute()
+            assert result is None
+            mock_hook.run.assert_has_calls(
+                [
+                    call(
+                        "SELECT count(1) from DBC.StorageV WHERE StorageName='TD_OFSSTORAGE'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        "SELECT InfoData AS Version FROM DBC.DBCInfoV WHERE InfoKey = 'VERSION'",
+                        handler=_single_result_row_handler,
+                    ),
+                    call(
+                        f"SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = UPPER('{compute_profile_name}')",
+                        handler=_single_result_row_handler,
+                    ),
+                ]
+            )
diff --git a/tests/providers/teradata/triggers/__init__.py b/tests/providers/teradata/triggers/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/tests/providers/teradata/triggers/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/teradata/triggers/test_teradata_compute_cluster.py b/tests/providers/teradata/triggers/test_teradata_compute_cluster.py
new file mode 100644
index 0000000000000..e272d468e41ce
--- /dev/null
+++ b/tests/providers/teradata/triggers/test_teradata_compute_cluster.py
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import call, patch
+
+import pytest
+
+from airflow.providers.common.sql.hooks.sql import fetch_one_handler
+from airflow.providers.teradata.hooks.teradata import TeradataHook
+from airflow.providers.teradata.triggers.teradata_compute_cluster import TeradataComputeClusterSyncTrigger
+from airflow.providers.teradata.utils.constants import Constants
+from airflow.triggers.base import TriggerEvent
+
+
+@pytest.mark.asyncio
+async def test_run_suspend_success():
+    """SUSPEND emits a success event once get_status reports the suspended state."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        operation_type=Constants.CC_SUSPEND_OPR,
+        poll_interval=1,
+    )
+    with patch.object(trigger, "get_status") as mock_get_status:
+        mock_get_status.return_value = Constants.CC_SUSPEND_DB_STATUS
+        async for event in trigger.run():
+            assert event == TriggerEvent(
+                {
+                    "status": "success",
+                    "message": Constants.CC_OPR_SUCCESS_STATUS_MSG
+                    % ("test_profile", Constants.CC_SUSPEND_OPR),
+                }
+            )
+        mock_get_status.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_run_suspend_success_cg():
+    """Same as the suspend success case, with the trigger scoped to a compute group."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        compute_group_name="test_group",
+        operation_type=Constants.CC_SUSPEND_OPR,
+        poll_interval=1,
+    )
+    with patch.object(trigger, "get_status") as mock_get_status:
+        mock_get_status.return_value = Constants.CC_SUSPEND_DB_STATUS
+        async for event in trigger.run():
+            assert event == TriggerEvent(
+                {
+                    "status": "success",
+                    "message": Constants.CC_OPR_SUCCESS_STATUS_MSG
+                    % ("test_profile", Constants.CC_SUSPEND_OPR),
+                }
+            )
+        mock_get_status.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_run_suspend_failure():
+    """A missing status (None) during SUSPEND yields an error event with the not-exists message."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        operation_type=Constants.CC_SUSPEND_OPR,
+        poll_interval=1,
+    )
+    with patch.object(trigger, "get_status") as mock_get_status:
+        mock_get_status.return_value = None
+        async for event in trigger.run():
+            assert event == TriggerEvent({"status": "error", "message": Constants.CC_GRP_PRP_NON_EXISTS_MSG})
+        mock_get_status.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_run_resume_success():
+    """RESUME emits a success event once get_status reports the running state."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        operation_type=Constants.CC_RESUME_OPR,
+        poll_interval=1,
+    )
+    with patch.object(trigger, "get_status") as mock_get_status:
+        mock_get_status.return_value = Constants.CC_RESUME_DB_STATUS
+        async for event in trigger.run():
+            assert event == TriggerEvent(
+                {
+                    "status": "success",
+                    "message": Constants.CC_OPR_SUCCESS_STATUS_MSG
+                    % ("test_profile", Constants.CC_RESUME_OPR),
+                }
+            )
+        mock_get_status.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_run_resume_failure():
+    """A missing status (None) during RESUME yields an error event with the not-exists message."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        operation_type=Constants.CC_RESUME_OPR,
+        poll_interval=1,
+    )
+    with patch.object(trigger, "get_status") as mock_get_status:
+        mock_get_status.return_value = None
+        async for event in trigger.run():
+            assert event == TriggerEvent({"status": "error", "message": Constants.CC_GRP_PRP_NON_EXISTS_MSG})
+        mock_get_status.assert_called_once()
+
+
+@pytest.fixture
+def mock_teradata_hook_run():
+    """Replace TeradataHook.run with a mock for the duration of a test."""
+    with patch.object(TeradataHook, "run") as mock_run:
+        yield mock_run
+
+
+@pytest.mark.asyncio
+async def test_get_status(mock_teradata_hook_run):
+    """get_status looks the profile up by name and returns the first element of the hook result."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        operation_type=Constants.CC_SUSPEND_OPR,
+        poll_interval=1,
+    )
+    mock_teradata_hook_run.return_value = [Constants.CC_SUSPEND_DB_STATUS]
+    status = await trigger.get_status()
+    assert status == Constants.CC_SUSPEND_DB_STATUS
+    mock_teradata_hook_run.assert_called_once()
+    mock_teradata_hook_run.assert_has_calls(
+        [
+            call(
+                "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = "
+                "UPPER('test_profile')",
+                handler=fetch_one_handler,
+            ),
+        ]
+    )
+
+
+@pytest.mark.asyncio
+async def test_get_status_cg(mock_teradata_hook_run):
+    """With a compute group set, get_status also filters the lookup on the group name."""
+    trigger = TeradataComputeClusterSyncTrigger(
+        teradata_conn_id="test_conn_id",
+        compute_profile_name="test_profile",
+        compute_group_name="test_group",
+        operation_type=Constants.CC_RESUME_OPR,
+        poll_interval=1,
+    )
+    mock_teradata_hook_run.return_value = [Constants.CC_RESUME_DB_STATUS]
+    status = await trigger.get_status()
+    assert status == Constants.CC_RESUME_DB_STATUS
+    mock_teradata_hook_run.assert_called_once()
+    mock_teradata_hook_run.assert_has_calls(
+        [
+            call(
+                "SEL ComputeProfileState FROM DBC.ComputeProfilesVX WHERE UPPER(ComputeProfileName) = "
+                "UPPER('test_profile') AND UPPER(ComputeGroupName) = UPPER('test_group')",
+                handler=fetch_one_handler,
+            ),
+        ]
+    )
diff --git a/tests/providers/teradata/utils/__init__.py b/tests/providers/teradata/utils/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/tests/providers/teradata/utils/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/teradata/utils/test_constants.py b/tests/providers/teradata/utils/test_constants.py
new file mode 100644
index 0000000000000..f4410f3b2c4d1
--- /dev/null
+++ b/tests/providers/teradata/utils/test_constants.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.teradata.utils.constants import Constants
+
+# Each test pins one Constants attribute to its literal value, so any change to
+# that value in the constants module shows up here as a failing assertion.
+
+
+def test_create_operations():
+    assert Constants.CC_CREATE_OPR == "CREATE"
+
+
+def test_create_suspend_operations():
+    assert Constants.CC_CREATE_SUSPEND_OPR == "CREATE_SUSPEND"
+
+
+def test_drop_operations():
+    assert Constants.CC_DROP_OPR == "DROP"
+
+
+def test_suspend_operations():
+    assert Constants.CC_SUSPEND_OPR == "SUSPEND"
+
+
+def test_resume_operations():
+    assert Constants.CC_RESUME_OPR == "RESUME"
+
+
+def test_initialize_db_status():
+    assert Constants.CC_INITIALIZE_DB_STATUS == "Initializing"
+
+
+def test_suspend_db_status():
+    assert Constants.CC_SUSPEND_DB_STATUS == "Suspended"
+
+
+def test_resume_db_status():
+    assert Constants.CC_RESUME_DB_STATUS == "Running"
+
+
+def test_operation_success_message():
+    assert Constants.CC_OPR_SUCCESS_STATUS_MSG == "Compute Cluster %s %s operation completed successfully."
+
+
+def test_operation_failure_message():
+    assert Constants.CC_OPR_FAILURE_STATUS_MSG == "Compute Cluster %s %s operation has failed."
+
+
+def test_initializing_status_message():
+    assert Constants.CC_OPR_INITIALIZING_STATUS_MSG == (
+        "The environment is currently initializing. Please wait."
+    )
+
+
+def test_empty_profile_error_message():
+    assert Constants.CC_OPR_EMPTY_PROFILE_ERROR_MSG == (
+        "Please provide a valid name for the compute cluster profile."
+    )
+
+
+def test_non_exists_message():
+    assert Constants.CC_GRP_PRP_NON_EXISTS_MSG == (
+        "The specified Compute cluster is not present or The user doesn't have permission to access compute cluster."
+    )
+
+
+def test_unauthorized_message():
+    assert Constants.CC_GRP_PRP_UN_AUTHORIZED_MSG == "The %s operation is not authorized for the user."
+
+
+def test_lake_support_only_message():
+    assert Constants.CC_GRP_LAKE_SUPPORT_ONLY_MSG == (
+        "Compute Groups is supported only on Vantage Cloud Lake."
+    )
+
+
+def test_timeout_error_message():
+    assert Constants.CC_OPR_TIMEOUT_ERROR == (
+        "There is an issue with the %s operation. Kindly consult the administrator for assistance."
+    )
+
+
+def test_exists_message():
+    assert Constants.CC_GRP_PRP_EXISTS_MSG == "The specified Compute cluster is already exists."
+
+
+def test_empty_copy_profile_error_message():
+    assert Constants.CC_OPR_EMPTY_COPY_PROFILE_ERROR_MSG == (
+        "Please provide a valid name for the source and target compute profile."
+    )
+
+
+def test_timeout_value():
+    assert Constants.CC_OPR_TIME_OUT == 1200
+
+
+def test_poll_interval():
+    assert Constants.CC_POLL_INTERVAL == 60
diff --git a/tests/system/providers/teradata/example_teradata_compute_cluster.py b/tests/system/providers/teradata/example_teradata_compute_cluster.py
new file mode 100644
index 0000000000000..3fefe9858770a
--- /dev/null
+++ b/tests/system/providers/teradata/example_teradata_compute_cluster.py
@@ -0,0 +1,158 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example use of the Teradata Compute Cluster Provision, Suspend, Resume and Decommission operators.
+"""
+
+from __future__ import annotations
+
+import datetime
+import os
+
+import pytest
+
+from airflow import DAG
+from airflow.models import Param
+
+try:
+    from airflow.providers.teradata.operators.teradata_compute_cluster import (
+        TeradataComputeClusterDecommissionOperator,
+        TeradataComputeClusterProvisionOperator,
+        TeradataComputeClusterResumeOperator,
+        TeradataComputeClusterSuspendOperator,
+    )
+except ImportError:
+    pytest.skip("TERADATA provider not available", allow_module_level=True)
+
+# [START teradata_vantage_lake_compute_cluster_howto_guide]
+
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "example_teradata_computer_cluster"
+
+with DAG(
+    dag_id=DAG_ID,
+    start_date=datetime.datetime(2020, 2, 2),
+    schedule="@once",
+    catchup=False,
+    default_args={"teradata_conn_id": "teradata_lake"},
+    render_template_as_native_obj=True,
+    params={
+        "compute_group_name": Param(
+            "compute_group_test",
+            type="string",
+            title="Compute cluster group Name:",
+            description="Enter compute cluster group name.",
+        ),
+        "compute_profile_name": Param(
+            "compute_profile_test",
+            type="string",
+            title="Compute cluster profile Name:",
+            description="Enter compute cluster profile name.",
+        ),
+        "query_strategy": Param(
+            "STANDARD",
+            type="string",
+            title="Compute cluster instance type:",
+            description="Enter compute cluster instance type. Valid values are STANDARD, ANALYTIC",
+        ),
+        "compute_map": Param(
+            "TD_COMPUTE_XSMALL",
+            type="string",
+            title="Compute Map Name:",
+            description="Enter compute cluster compute map name.",
+        ),
+        "compute_attribute": Param(
+            "MIN_COMPUTE_COUNT(1) MAX_COMPUTE_COUNT(5) INITIALLY_SUSPENDED('FALSE')",
+            type="string",
+            title="Compute cluster compute attribute:",
+            description="Enter compute cluster compute attribute values.",
+        ),
+        "teradata_conn_id": Param(
+            "teradata_lake",
+            type="string",
+            title="Teradata ConnectionId:",
+            description="Enter Teradata connection id.",
+        ),
+        "timeout": Param(
+            20,
+            type="integer",
+            title="Timeout:",
+            description="Time elapsed before the task times out and fails. Timeout is in minutes.",
+        ),
+    },
+) as dag:
+    # [START teradata_vantage_lake_compute_cluster_provision_howto_guide]
+    compute_cluster_provision_operation = TeradataComputeClusterProvisionOperator(
+        task_id="compute_cluster_provision_operation",
+        compute_profile_name="{{ params.compute_profile_name }}",
+        compute_group_name="{{ params.compute_group_name }}",
+        teradata_conn_id="{{ params.teradata_conn_id }}",
+        timeout="{{ params.timeout }}",
+        query_strategy="{{ params.query_strategy }}",
+        compute_map="{{ params.compute_map }}",
+        compute_attribute="{{ params.compute_attribute }}",
+    )
+    # [END teradata_vantage_lake_compute_cluster_provision_howto_guide]
+    # [START teradata_vantage_lake_compute_cluster_suspend_howto_guide]
+    compute_cluster_suspend_operation = TeradataComputeClusterSuspendOperator(
+        task_id="compute_cluster_suspend_operation",
+        compute_profile_name="{{ params.compute_profile_name }}",
+        compute_group_name="{{ params.compute_group_name }}",
+        teradata_conn_id="{{ params.teradata_conn_id }}",
+        timeout="{{ params.timeout }}",
+    )
+    # [END teradata_vantage_lake_compute_cluster_suspend_howto_guide]
+    # [START teradata_vantage_lake_compute_cluster_resume_howto_guide]
+    compute_cluster_resume_operation = TeradataComputeClusterResumeOperator(
+        task_id="compute_cluster_resume_operation",
+        compute_profile_name="{{ params.compute_profile_name }}",
+        compute_group_name="{{ params.compute_group_name }}",
+        teradata_conn_id="{{ params.teradata_conn_id }}",
+        timeout="{{ params.timeout }}",
+    )
+    # [END teradata_vantage_lake_compute_cluster_resume_howto_guide]
+    # [START teradata_vantage_lake_compute_cluster_decommission_howto_guide]
+    compute_cluster_decommission_operation = TeradataComputeClusterDecommissionOperator(
+        task_id="compute_cluster_decommission_operation",
+        compute_profile_name="{{ params.compute_profile_name }}",
+        compute_group_name="{{ params.compute_group_name }}",
+        delete_compute_group=True,
+        teradata_conn_id="{{ params.teradata_conn_id }}",
+        timeout="{{ params.timeout }}",
+    )
+    # [END teradata_vantage_lake_compute_cluster_decommission_howto_guide]
+    (
+        compute_cluster_provision_operation
+        >> compute_cluster_suspend_operation
+        >> compute_cluster_resume_operation
+        >> compute_cluster_decommission_operation
+    )
+
+    # [END teradata_vantage_lake_compute_cluster_howto_guide]
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)