From e97c04d8f424744df303ebd08e89b54155a40fc0 Mon Sep 17 00:00:00 2001 From: Alex Lopez Date: Mon, 17 Jul 2023 15:45:53 +0200 Subject: [PATCH] Process query rows one at a time to reduce memory footprint (#15268) * Process query rows one at a time to reduce memory footprint * Remove `fetchall` from mock --- snowflake/datadog_checks/snowflake/check.py | 9 +++++---- .../_snowflake_connector_patch/connector.py | 15 +++++++++++---- snowflake/tests/test_snowflake.py | 2 +- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/snowflake/datadog_checks/snowflake/check.py b/snowflake/datadog_checks/snowflake/check.py index d3517a3cc1fd8..1a40c1454ed50 100644 --- a/snowflake/datadog_checks/snowflake/check.py +++ b/snowflake/datadog_checks/snowflake/check.py @@ -155,8 +155,9 @@ def execute_query_raw(self, query): if cursor.rowcount is None or cursor.rowcount < 1: self.log.debug("Failed to fetch records from query: `%s`", query) - return [] - return cursor.fetchall() + return + # Iterating on the cursor provides one row at a time without loading all of them at once + yield from cursor def connect(self): self.log.debug( @@ -209,8 +210,8 @@ def connect(self): @AgentCheck.metadata_entrypoint def _collect_version(self): try: - raw_version = self.execute_query_raw("select current_version();") - version = raw_version[0][0] + raw_version = next(self.execute_query_raw("select current_version();")) + version = raw_version[0] except Exception as e: self.log.error("Error collecting version for Snowflake: %s", e) else: diff --git a/snowflake/tests/snowflake_connector_patch/_snowflake_connector_patch/connector.py b/snowflake/tests/snowflake_connector_patch/_snowflake_connector_patch/connector.py index c4e7c696fb7fa..70fc9ac93bcea 100644 --- a/snowflake/tests/snowflake_connector_patch/_snowflake_connector_patch/connector.py +++ b/snowflake/tests/snowflake_connector_patch/_snowflake_connector_patch/connector.py @@ -4,6 +4,7 @@ import re import requests +from snowflake.connector.cursor import SnowflakeCursor from . import tables @@ -49,14 +50,20 @@ def execute(self, query): if self.schema == 'ORGANIZATION_USAGE': table_prefix = 'ORGANIZATION_' table_attr = "{}{}".format(table_prefix, table_name) - self.__data = getattr(tables, table_attr, []) + self.__data = list(getattr(tables, table_attr, [])) elif query == 'select current_version();': self.__data = [('4.30.2',)] else: self.__data = [] - def fetchall(self): - return self.__data + def fetchone(self): + try: + return self.__data.pop(0) + except IndexError: + return None + + def __iter__(self): + return SnowflakeCursor.__iter__(self) def close(self): - pass + self.__data = [] diff --git a/snowflake/tests/test_snowflake.py b/snowflake/tests/test_snowflake.py index 3c80addcadae2..f17dbb225afe2 100644 --- a/snowflake/tests/test_snowflake.py +++ b/snowflake/tests/test_snowflake.py @@ -123,7 +123,7 @@ def test_version_metadata(dd_run_check, instance, datadog_agent): 'version.raw': '4.30.2', 'version.scheme': 'semver', } - with mock.patch('datadog_checks.snowflake.SnowflakeCheck.execute_query_raw', return_value=expected_version): + with mock.patch('datadog_checks.snowflake.SnowflakeCheck.execute_query_raw', return_value=iter(expected_version)): check = SnowflakeCheck(CHECK_NAME, {}, [instance]) check.check_id = 'test:123' check._conn = mock.MagicMock()