Skip to content

Commit

Permalink
HJ-116 - Fix BigQuery partitioning queries to properly support multiple identity clauses (#5432)
Browse files Browse the repository at this point in the history

Co-authored-by: Neville Samuell <neville@ethyca.com>
  • Loading branch information
andres-torres-marroquin and NevilleS committed Oct 31, 2024
1 parent 5b5031e commit 34e00b1
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The types of changes are:

### Fixed
- API router sanitizer being too aggressive with NextJS Catch-all Segments [#5438](https://github.com/ethyca/fides/pull/5438)
- Fix BigQuery `partitioning` queries to properly support multiple identity clauses [#5432](https://github.com/ethyca/fides/pull/5432)

## [2.48.0](https://github.com/ethyca/fidesplus/compare/2.47.1...2.48.0)

Expand Down
5 changes: 5 additions & 0 deletions data/dataset/bigquery_example_test_dataset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ dataset:
fides_meta:
identity: email
data_type: string
- name: custom_id
data_categories: [user.unique_id]
fides_meta:
identity: custom_id
data_type: string
- name: id
data_categories: [user.unique_id]
fides_meta:
Expand Down
10 changes: 5 additions & 5 deletions src/fides/api/service/connectors/query_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def get_formatted_query_string(
) -> str:
"""Returns a query string with double quotation mark formatting for tables that have the same names as
Postgres reserved words."""
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE {" OR ".join(clauses)}'
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})'


class MySQLQueryConfig(SQLQueryConfig):
Expand All @@ -688,7 +688,7 @@ def get_formatted_query_string(
) -> str:
"""Returns a query string with backtick formatting for tables that have the same names as
MySQL reserved words."""
return f'SELECT {field_list} FROM `{self.node.collection.name}` WHERE {" OR ".join(clauses)}'
return f'SELECT {field_list} FROM `{self.node.collection.name}` WHERE ({" OR ".join(clauses)})'


class QueryStringWithoutTuplesOverrideQueryConfig(SQLQueryConfig):
Expand Down Expand Up @@ -797,7 +797,7 @@ def get_formatted_query_string(
clauses: List[str],
) -> str:
"""Returns a query string with double quotation mark formatting as required by Snowflake syntax."""
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE {" OR ".join(clauses)}'
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})'

Check warning on line 800 in src/fides/api/service/connectors/query_config.py

View check run for this annotation

Codecov / codecov/patch

src/fides/api/service/connectors/query_config.py#L800

Added line #L800 was not covered by tests

def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]:
"""Adds the appropriate formatting for update statements in this datastore."""
Expand All @@ -823,7 +823,7 @@ def get_formatted_query_string(
) -> str:
"""Returns a query string with double quotation mark formatting for tables that have the same names as
Redshift reserved words."""
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE {" OR ".join(clauses)}'
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})'

Check warning on line 826 in src/fides/api/service/connectors/query_config.py

View check run for this annotation

Codecov / codecov/patch

src/fides/api/service/connectors/query_config.py#L826

Added line #L826 was not covered by tests


class GoogleCloudSQLPostgresQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig):
Expand Down Expand Up @@ -896,7 +896,7 @@ def get_formatted_query_string(
Returns a query string with backtick formatting for tables that have the same names as
BigQuery reserved words.
"""
return f'SELECT {field_list} FROM `{self._generate_table_name()}` WHERE {" OR ".join(clauses)}'
return f'SELECT {field_list} FROM `{self._generate_table_name()}` WHERE ({" OR ".join(clauses)})'

Check warning on line 899 in src/fides/api/service/connectors/query_config.py

View check run for this annotation

Codecov / codecov/patch

src/fides/api/service/connectors/query_config.py#L899

Added line #L899 was not covered by tests

def generate_masking_stmt(
self,
Expand Down
51 changes: 31 additions & 20 deletions tests/fixtures/bigquery_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def bigquery_connection_config_without_secrets(db: Session) -> Generator:


@pytest.fixture(scope="function")
def bigquery_connection_config(db: Session) -> Generator:
def bigquery_connection_config(db: Session, bigquery_keyfile_creds) -> Generator:
connection_config = ConnectionConfig.create(
db=db,
data={
Expand All @@ -46,23 +46,41 @@ def bigquery_connection_config(db: Session) -> Generator:
},
)
# Pulling from integration config file or GitHub secrets
keyfile_creds = integration_config.get("bigquery", {}).get(
"keyfile_creds"
) or ast.literal_eval(os.environ.get("BIGQUERY_KEYFILE_CREDS"))
dataset = integration_config.get("bigquery", {}).get("dataset") or os.environ.get(
"BIGQUERY_DATASET"
)
if keyfile_creds:
schema = BigQuerySchema(keyfile_creds=keyfile_creds, dataset=dataset)
if bigquery_keyfile_creds:
schema = BigQuerySchema(keyfile_creds=bigquery_keyfile_creds, dataset=dataset)
connection_config.secrets = schema.model_dump(mode="json")
connection_config.save(db=db)

yield connection_config
connection_config.delete(db)


@pytest.fixture(scope="session")
def bigquery_keyfile_creds():
    """Return BigQuery keyfile credentials for integration tests.

    Credentials are pulled from the integration config file first, falling
    back to the ``BIGQUERY_KEYFILE_CREDS`` environment variable (a GitHub
    secret containing a stringified dict).

    Raises:
        RuntimeError: if no credentials are available from either source.
    """
    keyfile_creds = integration_config.get("bigquery", {}).get("keyfile_creds")

    if not keyfile_creds and "BIGQUERY_KEYFILE_CREDS" in os.environ:
        # The GitHub secret is a stringified Python dict; literal_eval parses
        # it without the arbitrary-code risk of eval().
        keyfile_creds = ast.literal_eval(os.environ["BIGQUERY_KEYFILE_CREDS"])

    if not keyfile_creds:
        raise RuntimeError("Missing keyfile_creds for BigQuery")

    # NOTE: the original mixed an early `return` with a trailing `yield`,
    # making the function a generator whose early return never yielded a
    # value — pytest fails with "fixture did not yield" whenever the creds
    # come from the config file. No teardown is needed, so a plain
    # return-style fixture is used consistently.
    return keyfile_creds


@pytest.fixture(scope="function")
def bigquery_connection_config_without_default_dataset(db: Session) -> Generator:
def bigquery_connection_config_without_default_dataset(
db: Session, bigquery_keyfile_creds
) -> Generator:
connection_config = ConnectionConfig.create(
db=db,
data={
Expand All @@ -72,12 +90,8 @@ def bigquery_connection_config_without_default_dataset(db: Session) -> Generator
"access": AccessLevel.write,
},
)
# Pulling from integration config file or GitHub secrets
keyfile_creds = integration_config.get("bigquery", {}).get(
"keyfile_creds"
) or ast.literal_eval(os.environ.get("BIGQUERY_KEYFILE_CREDS"))
if keyfile_creds:
schema = BigQuerySchema(keyfile_creds=keyfile_creds)
if bigquery_keyfile_creds:
schema = BigQuerySchema(keyfile_creds=bigquery_keyfile_creds)
connection_config.secrets = schema.model_dump(mode="json")
connection_config.save(db=db)

Expand Down Expand Up @@ -150,7 +164,7 @@ def bigquery_example_test_dataset_config_with_namespace_and_partitioning_meta(
bigquery_connection_config_without_default_dataset: ConnectionConfig,
db: Session,
example_datasets: List[Dict],
) -> Generator:
) -> Generator[DatasetConfig, None, None]:
bigquery_dataset = example_datasets[7]
bigquery_dataset["fides_meta"] = {
"namespace": {
Expand Down Expand Up @@ -360,7 +374,7 @@ def bigquery_resources_with_namespace_meta(


@pytest.fixture(scope="session")
def bigquery_test_engine() -> Generator:
def bigquery_test_engine(bigquery_keyfile_creds) -> Generator:
"""Return a connection to a Google BigQuery Warehouse"""

connection_config = ConnectionConfig(
Expand All @@ -370,14 +384,11 @@ def bigquery_test_engine() -> Generator:
)

# Pulling from integration config file or GitHub secrets
keyfile_creds = integration_config.get("bigquery", {}).get(
"keyfile_creds"
) or ast.literal_eval(os.environ.get("BIGQUERY_KEYFILE_CREDS"))
dataset = integration_config.get("bigquery", {}).get("dataset") or os.environ.get(
"BIGQUERY_DATASET"
)
if keyfile_creds:
schema = BigQuerySchema(keyfile_creds=keyfile_creds, dataset=dataset)
if bigquery_keyfile_creds:
schema = BigQuerySchema(keyfile_creds=bigquery_keyfile_creds, dataset=dataset)
connection_config.secrets = schema.model_dump(mode="json")

connector: BigQueryConnector = get_connector(connection_config)
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/api/v1/endpoints/test_privacy_request_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3501,7 +3501,7 @@ def test_request_preview(
if response["collectionAddress"]["dataset"] == "postgres"
if response["collectionAddress"]["collection"] == "subscriptions"
)
== 'SELECT email, id FROM "subscriptions" WHERE email = ?'
== 'SELECT email, id FROM "subscriptions" WHERE (email = ?)'
)

def test_request_preview_incorrect_body(
Expand Down Expand Up @@ -3578,7 +3578,7 @@ def test_request_preview_all(
if response["collectionAddress"]["dataset"] == "postgres"
if response["collectionAddress"]["collection"] == "subscriptions"
)
== 'SELECT email, id FROM "subscriptions" WHERE email = ?'
== 'SELECT email, id FROM "subscriptions" WHERE (email = ?)'
)
assert (
next(
Expand Down
24 changes: 12 additions & 12 deletions tests/ops/integration_tests/test_external_database_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,18 @@ def test_bigquery_example_data(bigquery_test_engine):
inspector = inspect(bigquery_test_engine)
assert sorted(inspector.get_table_names(schema="fidesopstest")) == sorted(
[
"address",
"customer",
"employee",
"login",
"order_item",
"orders",
"payment_card",
"product",
"report",
"service_request",
"visit",
"visit_partitioned",
"fidesopstest.address",
"fidesopstest.customer",
"fidesopstest.employee",
"fidesopstest.login",
"fidesopstest.order_item",
"fidesopstest.orders",
"fidesopstest.payment_card",
"fidesopstest.product",
"fidesopstest.report",
"fidesopstest.service_request",
"fidesopstest.visit",
"fidesopstest.visit_partitioned",
]
)

Expand Down
48 changes: 46 additions & 2 deletions tests/ops/service/connectors/test_bigquery_connector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Generator

import pytest
Expand Down Expand Up @@ -68,7 +69,9 @@ def execution_node_with_namespace_and_partitioning_meta(
dataset_config.connection_config.key,
)
dataset_graph = DatasetGraph(graph_dataset)
traversal = Traversal(dataset_graph, {"email": "customer-1@example.com"})
traversal = Traversal(
dataset_graph, {"email": "customer-1@example.com", "custom_id": "123"}
)

yield traversal.traversal_node_dict[
CollectionAddress("bigquery_example_test_dataset", "customer")
Expand Down Expand Up @@ -166,7 +169,7 @@ def test_retrieve_partitioned_data(
policy,
privacy_request_with_email_identity,
):
"""Unit test of BigQueryQueryConfig.generate_delete specifically for a partitioned table"""
"""Unit test of BigQueryQueryConfig.retrieve_data specifically for a partitioned table"""
dataset_config = (
bigquery_example_test_dataset_config_with_namespace_and_partitioning_meta
)
Expand All @@ -182,3 +185,44 @@ def test_retrieve_partitioned_data(

assert len(results) == 1
assert results[0]["email"] == "customer-1@example.com"

def test_retrieve_partitioned_data_with_multiple_identifying_fields(
    self,
    bigquery_example_test_dataset_config_with_namespace_and_partitioning_meta: DatasetConfig,
    execution_node_with_namespace_and_partitioning_meta,
    policy,
    privacy_request_with_email_identity,
    loguru_caplog,
):
    """Verify BigQueryQueryConfig.retrieve_data against a partitioned table
    when the traversal carries more than one identity clause: the OR'd
    identity predicates must be parenthesized before the partition window
    predicate is AND-ed on."""
    connector = BigQueryConnector(
        bigquery_example_test_dataset_config_with_namespace_and_partitioning_meta.connection_config
    )

    with loguru_caplog.at_level(logging.INFO):
        results = connector.retrieve_data(
            node=execution_node_with_namespace_and_partitioning_meta,
            policy=policy,
            privacy_request=privacy_request_with_email_identity,
            request_task=RequestTask(),
            input_data={
                "email": ["customer-1@example.com"],
                "custom_id": ["123"],
            },
        )

    # Assert on the SQL that sqlalchemy.engine.Engine logged — one query per
    # partition window. Inspecting logs may not be the cleanest approach, but
    # BigQueryConnector currently couples query generation to query execution,
    # so the generated SQL is not observable any other way.
    expected_queries = (
        "INFO sqlalchemy.engine.Engine:log.py:117 SELECT address_id, created, custom_id, email, id, name FROM `silken-precinct-284918.fidesopstest.customer` WHERE (email = %(email)s OR custom_id = %(custom_id)s) AND (`created` > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 1000 DAY) AND `created` <= CURRENT_TIMESTAMP())",
        "INFO sqlalchemy.engine.Engine:log.py:117 SELECT address_id, created, custom_id, email, id, name FROM `silken-precinct-284918.fidesopstest.customer` WHERE (email = %(email)s OR custom_id = %(custom_id)s) AND (`created` > TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 2000 DAY) AND `created` <= TIMESTAMP_SUB(CURRENT_TIMESTAMP(), INTERVAL 1000 DAY))",
    )
    for expected in expected_queries:
        assert expected in loguru_caplog.text

    assert len(results) == 1
    assert results[0]["email"] == "customer-1@example.com"
17 changes: 5 additions & 12 deletions tests/ops/service/connectors/test_queryconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ def test_put_query_param_formatting_single_key(
}


@pytest.mark.skip(reason="move to plus in progress")
@pytest.mark.integration_external
@pytest.mark.integration_bigquery
class TestBigQueryQueryConfig:
@pytest.fixture(scope="function")
def bigquery_client(self, bigquery_connection_config):
Expand Down Expand Up @@ -773,8 +774,6 @@ def address_node(self, dataset_graph):
CollectionAddress("bigquery_example_test_dataset", "address")
].to_mock_execution_node()

@pytest.mark.integration_external
@pytest.mark.integration_bigquery
def test_generate_update_stmt(
self,
db,
Expand Down Expand Up @@ -816,8 +815,6 @@ def test_generate_update_stmt(
== "UPDATE `address` SET `house`=%(house:STRING)s, `street`=%(street:STRING)s, `city`=%(city:STRING)s, `state`=%(state:STRING)s, `zip`=%(zip:STRING)s WHERE `address`.`id` = %(id_1:STRING)s"
)

@pytest.mark.integration_external
@pytest.mark.integration_bigquery
def test_generate_namespaced_update_stmt(
self,
db,
Expand Down Expand Up @@ -864,8 +861,6 @@ def test_generate_namespaced_update_stmt(
== "UPDATE `cool_project.first_dataset.address` SET `house`=%(house:STRING)s, `street`=%(street:STRING)s, `city`=%(city:STRING)s, `state`=%(state:STRING)s, `zip`=%(zip:STRING)s WHERE `address`.`id` = %(id_1:STRING)s"
)

@pytest.mark.integration_external
@pytest.mark.integration_bigquery
def test_generate_delete_stmt(
self,
db,
Expand Down Expand Up @@ -906,8 +901,6 @@ def test_generate_delete_stmt(
== "DELETE FROM `employee` WHERE `employee`.`id` = %(id_1:STRING)s"
)

@pytest.mark.integration_external
@pytest.mark.integration_bigquery
def test_generate_namespaced_delete_stmt(
self,
db,
Expand Down Expand Up @@ -1026,16 +1019,16 @@ def execution_node(
BigQueryNamespaceMeta(
project_id="cool_project", dataset_id="first_dataset"
),
"SELECT address_id, created, email, id, name FROM `cool_project.first_dataset.customer` WHERE email = :email",
"SELECT address_id, created, custom_id, email, id, name FROM `cool_project.first_dataset.customer` WHERE (email = :email)",
),
# Namespace meta will be a dict / JSON when retrieved from the DB
(
{"project_id": "cool_project", "dataset_id": "first_dataset"},
"SELECT address_id, created, email, id, name FROM `cool_project.first_dataset.customer` WHERE email = :email",
"SELECT address_id, created, custom_id, email, id, name FROM `cool_project.first_dataset.customer` WHERE (email = :email)",
),
(
None,
"SELECT address_id, created, email, id, name FROM `customer` WHERE email = :email",
"SELECT address_id, created, custom_id, email, id, name FROM `customer` WHERE (email = :email)",
),
],
)
Expand Down
8 changes: 4 additions & 4 deletions tests/ops/task/test_graph_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,22 +431,22 @@ def test_sql_dry_run_queries(db) -> None:

assert (
env[CollectionAddress("mysql", "Customer")]
== 'SELECT customer_id, name, email, contact_address_id FROM "Customer" WHERE email = ?'
== 'SELECT customer_id, name, email, contact_address_id FROM "Customer" WHERE (email = ?)'
)

assert (
env[CollectionAddress("mysql", "User")]
== 'SELECT id, user_id, name FROM "User" WHERE user_id = ?'
== 'SELECT id, user_id, name FROM "User" WHERE (user_id = ?)'
)

assert (
env[CollectionAddress("postgres", "Order")]
== 'SELECT order_id, customer_id, shipping_address_id, billing_address_id FROM "Order" WHERE customer_id IN (?, ?)'
== 'SELECT order_id, customer_id, shipping_address_id, billing_address_id FROM "Order" WHERE (customer_id IN (?, ?))'
)

assert (
env[CollectionAddress("mysql", "Address")]
== 'SELECT id, street, city, state, zip FROM "Address" WHERE id IN (?, ?)'
== 'SELECT id, street, city, state, zip FROM "Address" WHERE (id IN (?, ?))'
)

assert (
Expand Down

0 comments on commit 34e00b1

Please sign in to comment.