diff --git a/src/fides/api/api/v1/endpoints/router_factory.py b/src/fides/api/api/v1/endpoints/router_factory.py
index adfddebd64..48d8f75c25 100644
--- a/src/fides/api/api/v1/endpoints/router_factory.py
+++ b/src/fides/api/api/v1/endpoints/router_factory.py
@@ -42,7 +42,9 @@
     forbid_if_editing_is_default,
 )
 from fides.common.api.scope_registry import CREATE, DELETE, READ, UPDATE
-from fides.service.dataset.dataset_validator import validate_data_categories_against_db
+from fides.service.dataset.validation_steps.data_category import (
+    validate_data_categories_against_db,
+)
 
 
 async def get_data_categories_from_db(async_session: AsyncSession) -> List[FidesKey]:
diff --git a/src/fides/service/dataset/dataset_service.py b/src/fides/service/dataset/dataset_service.py
index b611c89c95..1ed875d0de 100644
--- a/src/fides/service/dataset/dataset_service.py
+++ b/src/fides/service/dataset/dataset_service.py
@@ -33,11 +33,11 @@
 )
 from fides.api.schemas.redis_cache import Identity, LabeledIdentity
 from fides.api.util.data_category import get_data_categories_from_db
-from fides.service.dataset.dataset_validator import (
-    DatasetValidator,
-    TraversalValidationStep,
+from fides.service.dataset.dataset_validator import DatasetValidator
+from fides.service.dataset.validation_steps.data_category import (
     validate_data_categories_against_db,
 )
+from fides.service.dataset.validation_steps.traversal import TraversalValidationStep
 
 from fides.api.models.sql_models import (  # type: ignore[attr-defined]  # isort: skip
     Dataset as CtlDataset,
"""Context object holding state for validation""" @@ -55,95 +25,17 @@ def __init__( self.traversal_details: Optional[DatasetTraversalDetails] = None -class MaskingStrategyValidationStep(DatasetValidationStep): - """Validates masking strategy overrides""" - - def validate(self, context: DatasetValidationContext) -> None: - """ - Validates that field-level masking overrides do not require secret keys. - When handling a privacy request, we use the `cache_data` function to review the policies and identify which masking strategies need secret keys generated and cached. - Currently, we are avoiding the additional complexity of scanning datasets for masking overrides. - """ - - def validate_field(dataset_field: DatasetField) -> None: - if dataset_field.fields: - for subfield in dataset_field.fields: - validate_field(subfield) - else: - if ( - dataset_field.fides_meta - and dataset_field.fides_meta.masking_strategy_override - ): - strategy: MaskingStrategy = MaskingStrategy.get_strategy( - dataset_field.fides_meta.masking_strategy_override.strategy, - dataset_field.fides_meta.masking_strategy_override.configuration, # type: ignore[arg-type] - ) - if strategy.secrets_required(): - raise ValidationError( - f"Masking strategy '{strategy.name}' with required secrets not allowed as an override." - ) - - for collection in context.dataset.collections: - for field in collection.fields: - validate_field(field) - - -class DataCategoryValidationStep(DatasetValidationStep): - """Validates data categories against database""" - - def validate(self, context: DatasetValidationContext) -> None: - defined_data_categories = get_data_categories_from_db(context.db) - validate_data_categories_against_db(context.dataset, defined_data_categories) - - -class SaaSValidationStep(DatasetValidationStep): - """Validates SaaS-specific requirements""" - - def validate(self, context: DatasetValidationContext) -> None: - if ( - context.connection_config - and context.connection_config.connection_type == ConnectionType.saas - ): - _validate_saas_dataset(context.connection_config, context.dataset) - +class DatasetValidationStep(ABC): + """Abstract base class for dataset validation steps""" -class TraversalValidationStep(DatasetValidationStep): - """Validates dataset traversability""" + @classmethod + def _find_all_validation_steps(cls: Type[T]) -> List[Type[T]]: + """Find all subclasses of DatasetValidationStep""" + return cls.__subclasses__() + @abstractmethod def validate(self, context: DatasetValidationContext) -> None: - if not context.connection_config: - logger.warning( - "Skipping traversal validation, no connection config provided" - ) - return - - try: - graph = convert_dataset_to_graph( - context.dataset, context.connection_config.key - ) - - if ( - context.connection_config - and context.connection_config.connection_type == ConnectionType.saas - ): - graph = merge_datasets( - graph, - context.connection_config.get_saas_config().get_graph( - context.connection_config.secrets - ), - ) - - complete_graph = DatasetGraph(graph) - unique_identities = set(complete_graph.identity_keys.values()) - Traversal(complete_graph, {k: None for k in unique_identities}) - - context.traversal_details = DatasetTraversalDetails( - is_traversable=True, msg=None - ) - except (TraversalError, ValidationError) as err: - context.traversal_details = DatasetTraversalDetails( - is_traversable=False, msg=str(err) - ) + """Perform validation step""" class DatasetValidator: @@ -173,95 +65,3 @@ def validate(self) -> ValidateDatasetResponse: 
             dataset=self.context.dataset,
             traversal_details=self.context.traversal_details,
         )
-
-
-def validate_data_categories_against_db(
-    dataset: FideslangDataset, defined_data_categories: List[FidesKey]
-) -> None:
-    """
-    Validate that data_categories defined on the Dataset, Collection, and Field levels exist
-    in the database. Doing this instead of a traditional validator function to have
-    access to a database session.
-
-    If no data categories in the database, default to using data categories from the default taxonomy.
-    """
-    if not defined_data_categories:
-        logger.info(
-            "No data categories in the database: reverting to default data categories."
-        )
-        defined_data_categories = [
-            FidesKey(key) for key in DefaultTaxonomyDataCategories.__members__.keys()
-        ]
-
-    class DataCategoryValidationMixin(BaseModel):
-        @field_validator("data_categories", check_fields=False)
-        @classmethod
-        def valid_data_categories(
-            cls: Type["DataCategoryValidationMixin"], v: Optional[List[FidesKey]]
-        ) -> Optional[List[FidesKey]]:
-            """Validate that all annotated data categories exist in the taxonomy"""
-            return _valid_data_categories(v, defined_data_categories)
-
-    class FieldDataCategoryValidation(DatasetField, DataCategoryValidationMixin):
-        fields: Optional[List["FieldDataCategoryValidation"]] = None  # type: ignore[assignment]
-
-    FieldDataCategoryValidation.model_rebuild()
-
-    class CollectionDataCategoryValidation(
-        DatasetCollection, DataCategoryValidationMixin
-    ):
-        fields: Sequence[FieldDataCategoryValidation] = []  # type: ignore[assignment]
-
-    class DatasetDataCategoryValidation(FideslangDataset, DataCategoryValidationMixin):
-        collections: Sequence[CollectionDataCategoryValidation]  # type: ignore[assignment]
-
-    DatasetDataCategoryValidation(**dataset.model_dump(mode="json"))
-
-
-def _valid_data_categories(
-    proposed_data_categories: Optional[List[FidesKey]],
-    defined_data_categories: List[FidesKey],
-) -> Optional[List[FidesKey]]:
-    """
-    Ensure that every data category provided matches a valid defined data category.
-    Throws an error if any of the categories are invalid,
-    or otherwise returns the list of categories unchanged.
-    """
-
-    def validate_category(data_category: FidesKey) -> FidesKey:
-        if data_category not in defined_data_categories:
-            raise common_exceptions.DataCategoryNotSupported(
-                f"The data category {data_category} is not supported."
-            )
-        return data_category
-
-    if proposed_data_categories:
-        return [dc for dc in proposed_data_categories if validate_category(dc)]
-    return proposed_data_categories
-
-
-def _validate_saas_dataset(
-    connection_config: ConnectionConfig, dataset: FideslangDataset
-) -> None:
-    if connection_config.saas_config is None:
-        raise SaaSConfigNotFoundException(
-            f"Connection config '{connection_config.key}' must have a "
-            "SaaS config before validating or adding a dataset"
-        )
-
-    fides_key = connection_config.saas_config["fides_key"]
-    if fides_key != dataset.fides_key:
-        raise ValidationError(
-            f"The fides_key '{dataset.fides_key}' of the dataset "
-            f"does not match the fides_key '{fides_key}' "
-            "of the connection config"
-        )
-    for collection in dataset.collections:
-        for field in collection.fields:
-            graph_field = to_graph_field(field)
-            if graph_field.references or graph_field.identity:
-                raise ValidationError(
-                    "A dataset for a ConnectionConfig type of 'saas' is not "
-                    "allowed to have references or identities. Please add "
-                    "them to the SaaS config."
-                )
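
With this change, dataset_validator.py keeps only the abstract step, the shared context, and the DatasetValidator orchestrator; each concrete step moves into its own module under validation_steps/. Discovery happens through cls.__subclasses__(), so a step participates as soon as its defining module has been imported. A minimal, standalone sketch of that discovery-and-run pattern (the Step/run_all names are illustrative, not the fides code):

    from abc import ABC, abstractmethod
    from typing import Dict, List, Type


    class Step(ABC):
        @classmethod
        def find_all(cls) -> List[Type["Step"]]:
            # __subclasses__() only sees classes whose defining modules have
            # been imported -- hence the eager imports in validation_steps/__init__.py.
            return cls.__subclasses__()

        @abstractmethod
        def validate(self, context: Dict) -> None:
            """Raise on failure; return None on success."""


    class NonEmptyStep(Step):
        def validate(self, context: Dict) -> None:
            if not context.get("collections"):
                raise ValueError("dataset has no collections")


    def run_all(context: Dict) -> None:
        for step_cls in Step.find_all():
            step_cls().validate(context)


    run_all({"collections": ["users"]})  # passes; an empty dict would raise
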
diff --git a/src/fides/service/dataset/validation_steps/__init__.py b/src/fides/service/dataset/validation_steps/__init__.py
new file mode 100644
index 0000000000..747f962441
--- /dev/null
+++ b/src/fides/service/dataset/validation_steps/__init__.py
@@ -0,0 +1,13 @@
+import importlib
+import os
+import os.path
+
+# path to the current directory
+directory = os.path.dirname(__file__)
+
+# loop through the Python files in the validation_steps directory
+for filename in os.listdir(directory):
+    # ignore non-Python files and the __init__.py file
+    if filename.endswith(".py") and filename != "__init__.py":
+        # import the module so its DatasetValidationStep subclasses register
+        importlib.import_module(f"{__name__}.{filename[:-3]}")
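
The package __init__ imports every sibling module purely for its side effect: defining each DatasetValidationStep subclass so that __subclasses__() can see it before DatasetValidator runs. A functionally similar loader that could live in the same __init__.py uses pkgutil, which skips __init__.py and non-module files without string checks (an alternative shown for comparison, not what this PR ships):

    import importlib
    import pkgutil

    # Import each submodule of this package so its DatasetValidationStep
    # subclasses are defined before the validator looks them up.
    for module_info in pkgutil.iter_modules(__path__):
        importlib.import_module(f"{__name__}.{module_info.name}")
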
diff --git a/src/fides/service/dataset/validation_steps/data_category.py b/src/fides/service/dataset/validation_steps/data_category.py
new file mode 100644
index 0000000000..d66a3be42c
--- /dev/null
+++ b/src/fides/service/dataset/validation_steps/data_category.py
@@ -0,0 +1,88 @@
+from typing import List, Optional, Type
+
+from fideslang.models import Dataset as FideslangDataset
+from fideslang.models import DatasetCollection, DatasetField
+from fideslang.validation import FidesKey
+from loguru import logger
+from pydantic import BaseModel, field_validator
+
+from fides.api.common_exceptions import DataCategoryNotSupported
+from fides.api.util.data_category import DataCategory as DefaultTaxonomyDataCategories
+from fides.api.util.data_category import get_data_categories_from_db
+from fides.service.dataset.dataset_validator import (
+    DatasetValidationContext,
+    DatasetValidationStep,
+)
+
+
+def _valid_data_categories(
+    proposed_data_categories: Optional[List[FidesKey]],
+    defined_data_categories: List[FidesKey],
+) -> Optional[List[FidesKey]]:
+    """
+    Ensure that every data category provided matches a valid defined data category.
+    Throws an error if any of the categories are invalid,
+    or otherwise returns the list of categories unchanged.
+    """
+
+    def validate_category(data_category: FidesKey) -> FidesKey:
+        if data_category not in defined_data_categories:
+            raise DataCategoryNotSupported(
+                f"The data category {data_category} is not supported."
+            )
+        return data_category
+
+    if proposed_data_categories:
+        return [dc for dc in proposed_data_categories if validate_category(dc)]
+    return proposed_data_categories
+
+
+def validate_data_categories_against_db(
+    dataset: FideslangDataset, defined_data_categories: List[FidesKey]
+) -> None:
+    """
+    Validate that data_categories defined on the Dataset, Collection, and Field levels exist
+    in the database. Doing this instead of a traditional validator function to have
+    access to a database session.
+
+    If no data categories in the database, default to using data categories from the default taxonomy.
+    """
+    if not defined_data_categories:
+        logger.info(
+            "No data categories in the database: reverting to default data categories."
+        )
+        defined_data_categories = [
+            FidesKey(key) for key in DefaultTaxonomyDataCategories.__members__.keys()
+        ]
+
+    class DataCategoryValidationMixin(BaseModel):
+        @field_validator("data_categories", check_fields=False)
+        @classmethod
+        def valid_data_categories(
+            cls: Type["DataCategoryValidationMixin"], v: Optional[List[FidesKey]]
+        ) -> Optional[List[FidesKey]]:
+            """Validate that all annotated data categories exist in the taxonomy"""
+            return _valid_data_categories(v, defined_data_categories)
+
+    class FieldDataCategoryValidation(DatasetField, DataCategoryValidationMixin):
+        fields: Optional[List["FieldDataCategoryValidation"]] = None  # type: ignore[assignment]
+
+    FieldDataCategoryValidation.model_rebuild()
+
+    class CollectionDataCategoryValidation(
+        DatasetCollection, DataCategoryValidationMixin
+    ):
+        fields: List[FieldDataCategoryValidation] = []  # type: ignore[assignment]
+
+    class DatasetDataCategoryValidation(FideslangDataset, DataCategoryValidationMixin):
+        collections: List[CollectionDataCategoryValidation]  # type: ignore[assignment]
+
+    DatasetDataCategoryValidation(**dataset.model_dump(mode="json"))
+
+
+class DataCategoryValidationStep(DatasetValidationStep):
+    """Validates data categories against database"""
+
+    def validate(self, context: DatasetValidationContext) -> None:
+        defined_data_categories = get_data_categories_from_db(context.db)
+        validate_data_categories_against_db(context.dataset, defined_data_categories)
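
validate_data_categories_against_db works by deriving throwaway pydantic models that bolt a data_categories validator onto the fideslang Dataset/Collection/Field models, then re-parsing the dataset once; the first unknown category raises DataCategoryNotSupported. A usage sketch, assuming a minimal fideslang dataset (the dataset and field names here are invented):

    from fideslang.models import Dataset
    from fideslang.validation import FidesKey

    from fides.api.common_exceptions import DataCategoryNotSupported
    from fides.service.dataset.validation_steps.data_category import (
        validate_data_categories_against_db,
    )

    dataset = Dataset(
        fides_key="example_dataset",
        collections=[
            {
                "name": "users",
                "fields": [{"name": "email", "data_categories": ["user.contact.email"]}],
            }
        ],
    )

    # Passes: every annotated category appears in the defined list.
    validate_data_categories_against_db(dataset, [FidesKey("user.contact.email")])

    # Raises: the annotated category is not in the defined list.
    try:
        validate_data_categories_against_db(dataset, [FidesKey("system.operations")])
    except DataCategoryNotSupported as exc:
        print(exc)
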
diff --git a/src/fides/service/dataset/validation_steps/masking_strategy.py b/src/fides/service/dataset/validation_steps/masking_strategy.py
new file mode 100644
index 0000000000..0159ba3912
--- /dev/null
+++ b/src/fides/service/dataset/validation_steps/masking_strategy.py
@@ -0,0 +1,41 @@
+from fideslang.models import DatasetField
+
+from fides.api.common_exceptions import ValidationError
+from fides.api.service.masking.strategy.masking_strategy import MaskingStrategy
+from fides.service.dataset.dataset_validator import (
+    DatasetValidationContext,
+    DatasetValidationStep,
+)
+
+
+class MaskingStrategyValidationStep(DatasetValidationStep):
+    """Validates masking strategy overrides"""
+
+    def validate(self, context: DatasetValidationContext) -> None:
+        """
+        Validates that field-level masking overrides do not require secret keys.
+        When handling a privacy request, we use the `cache_data` function to review the policies and identify which masking strategies need secret keys generated and cached.
+        Currently, we are avoiding the additional complexity of scanning datasets for masking overrides.
+        """
+
+        def validate_field(dataset_field: DatasetField) -> None:
+            if dataset_field.fields:
+                for subfield in dataset_field.fields:
+                    validate_field(subfield)
+            else:
+                if (
+                    dataset_field.fides_meta
+                    and dataset_field.fides_meta.masking_strategy_override
+                ):
+                    strategy: MaskingStrategy = MaskingStrategy.get_strategy(
+                        dataset_field.fides_meta.masking_strategy_override.strategy,
+                        dataset_field.fides_meta.masking_strategy_override.configuration,  # type: ignore[arg-type]
+                    )
+                    if strategy.secrets_required():
+                        raise ValidationError(
+                            f"Masking strategy '{strategy.name}' with required secrets not allowed as an override."
+                        )
+
+        for collection in context.dataset.collections:
+            for field in collection.fields:
+                validate_field(field)
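
MaskingStrategyValidationStep recurses down to leaf fields, since only leaves can carry a masking_strategy_override, and rejects any override whose strategy would need generated secrets. The same walk in isolation, with illustrative types and an assumed set of secret-requiring strategy names:

    # Standalone sketch of the leaf-field walk; Field is illustrative,
    # not the fideslang model.
    from dataclasses import dataclass, field
    from typing import List, Optional


    @dataclass
    class Field:
        name: str
        override: Optional[str] = None  # stand-in for masking_strategy_override
        fields: List["Field"] = field(default_factory=list)


    # Assumption: example strategy names that would need generated secrets.
    SECRETS_REQUIRED = {"aes_encrypt", "hmac"}


    def check(f: Field) -> None:
        if f.fields:  # only leaf fields can carry an override
            for sub in f.fields:
                check(sub)
        elif f.override in SECRETS_REQUIRED:
            raise ValueError(f"override '{f.override}' requires secrets")


    check(Field("user", fields=[Field("email", override="string_rewrite")]))  # ok
    # check(Field("ssn", override="aes_encrypt")) would raise ValueError
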
diff --git a/src/fides/service/dataset/validation_steps/saas.py b/src/fides/service/dataset/validation_steps/saas.py
new file mode 100644
index 0000000000..17ec9f1d81
--- /dev/null
+++ b/src/fides/service/dataset/validation_steps/saas.py
@@ -0,0 +1,47 @@
+from fideslang.models import Dataset as FideslangDataset
+
+from fides.api.common_exceptions import SaaSConfigNotFoundException, ValidationError
+from fides.api.models.connectionconfig import ConnectionConfig, ConnectionType
+from fides.api.models.datasetconfig import to_graph_field
+from fides.service.dataset.dataset_validator import (
+    DatasetValidationContext,
+    DatasetValidationStep,
+)
+
+
+def _validate_saas_dataset(
+    connection_config: ConnectionConfig, dataset: FideslangDataset
+) -> None:
+    if connection_config.saas_config is None:
+        raise SaaSConfigNotFoundException(
+            f"Connection config '{connection_config.key}' must have a "
+            "SaaS config before validating or adding a dataset"
+        )
+
+    fides_key = connection_config.saas_config["fides_key"]
+    if fides_key != dataset.fides_key:
+        raise ValidationError(
+            f"The fides_key '{dataset.fides_key}' of the dataset "
+            f"does not match the fides_key '{fides_key}' "
+            "of the connection config"
+        )
+    for collection in dataset.collections:
+        for field in collection.fields:
+            graph_field = to_graph_field(field)
+            if graph_field.references or graph_field.identity:
+                raise ValidationError(
+                    "A dataset for a ConnectionConfig type of 'saas' is not "
+                    "allowed to have references or identities. Please add "
+                    "them to the SaaS config."
+                )
+
+
+class SaaSValidationStep(DatasetValidationStep):
+    """Validates SaaS-specific requirements"""
+
+    def validate(self, context: DatasetValidationContext) -> None:
+        if (
+            context.connection_config
+            and context.connection_config.connection_type == ConnectionType.saas
+        ):
+            _validate_saas_dataset(context.connection_config, context.dataset)
diff --git a/src/fides/service/dataset/validation_steps/traversal.py b/src/fides/service/dataset/validation_steps/traversal.py
new file mode 100644
index 0000000000..29b3d6c03d
--- /dev/null
+++ b/src/fides/service/dataset/validation_steps/traversal.py
@@ -0,0 +1,49 @@
+from loguru import logger
+
+from fides.api.common_exceptions import TraversalError, ValidationError
+from fides.api.graph.graph import DatasetGraph
+from fides.api.graph.traversal import Traversal
+from fides.api.models.connectionconfig import ConnectionType
+from fides.api.models.datasetconfig import convert_dataset_to_graph
+from fides.api.schemas.dataset import DatasetTraversalDetails
+from fides.api.util.saas_util import merge_datasets
+from fides.service.dataset.dataset_validator import (
+    DatasetValidationContext,
+    DatasetValidationStep,
+)
+
+
+class TraversalValidationStep(DatasetValidationStep):
+    """Validates dataset traversability"""
+
+    def validate(self, context: DatasetValidationContext) -> None:
+        if not context.connection_config:
+            logger.warning(
+                "Skipping traversal validation, no connection config provided"
+            )
+            return
+
+        try:
+            graph = convert_dataset_to_graph(
+                context.dataset, context.connection_config.key
+            )
+
+            if context.connection_config.connection_type == ConnectionType.saas:
+                graph = merge_datasets(
+                    graph,
+                    context.connection_config.get_saas_config().get_graph(
+                        context.connection_config.secrets
+                    ),
+                )
+
+            complete_graph = DatasetGraph(graph)
+            unique_identities = set(complete_graph.identity_keys.values())
+            Traversal(complete_graph, {k: None for k in unique_identities})
+
+            context.traversal_details = DatasetTraversalDetails(
+                is_traversable=True, msg=None
+            )
+        except (TraversalError, ValidationError) as err:
+            context.traversal_details = DatasetTraversalDetails(
+                is_traversable=False, msg=str(err)
+            )
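
TraversalValidationStep seeds the Traversal with a None value for every identity key in the graph; if every collection is reachable from those seeds the dataset is marked traversable, and failures are recorded in traversal_details rather than raised. An end-to-end usage sketch; the DatasetValidator constructor arguments are assumed from DatasetValidationContext (db, dataset, connection_config) and may differ:

    from fideslang.models import Dataset
    from sqlalchemy.orm import Session

    from fides.api.models.connectionconfig import ConnectionConfig
    from fides.service.dataset.dataset_validator import DatasetValidator


    def check_dataset(db: Session, dataset: Dataset, config: ConnectionConfig) -> bool:
        # Assumed constructor signature; the orchestrator runs every discovered
        # step. Traversal failures do not raise -- they land in traversal_details.
        response = DatasetValidator(db, dataset, config).validate()
        details = response.traversal_details
        if details and not details.is_traversable:
            print(f"Dataset is not traversable: {details.msg}")
            return False
        return True
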