diff --git a/dataprep/connector/connector.py b/dataprep/connector/connector.py index 87ac199b9..e2d63a3ca 100644 --- a/dataprep/connector/connector.py +++ b/dataprep/connector/connector.py @@ -5,7 +5,7 @@ import math import sys from asyncio import as_completed -from typing import Any, Awaitable, Dict, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, Optional, Tuple, Union, Set from warnings import warn import pandas as pd @@ -18,6 +18,8 @@ from .implicit_database import ImplicitDatabase, ImplicitTable from .ref import Ref from .schema import ( + ConfigDef, + FieldDef, FieldDefUnion, OffsetPaginationDef, PagePaginationDef, @@ -112,7 +114,19 @@ async def query( # pylint: disable=too-many-locals **where The additional parameters required for the query. """ - allowed_params = self._impdb.tables[table].config.request.params + allowed_params: Set[str] = set() + + for key, val in self._impdb.tables[table].config.request.params.items(): + if isinstance(val, FieldDef): + if isinstance(val.from_key, list): + allowed_params.update(val.from_key) + elif isinstance(val.from_key, str): + allowed_params.add(val.from_key) + else: + allowed_params.add(key) + else: + allowed_params.add(key) + for key in where: if key not in allowed_params: raise InvalidParameterError(key) @@ -399,20 +413,19 @@ def validate_fields(fields: Dict[str, FieldDefUnion], data: Dict[str, Any]) -> N """Check required fields are provided.""" for key, def_ in fields.items(): - from_key, to_key = key, key + to_key = key if isinstance(def_, bool): required = def_ if required and to_key not in data: - raise KeyError(f"'{from_key}' is required but not provided") + raise KeyError(f"'{to_key}' is required but not provided") elif isinstance(def_, str): pass else: to_key = def_.to_key or to_key - from_key = def_.from_key or from_key required = def_.required if required and to_key not in data: - raise KeyError(f"'{from_key}' is required but not provided") + raise KeyError(f"'{to_key}' is required but not provided") def populate_field( # pylint: disable=too-many-branches @@ -438,7 +451,8 @@ def populate_field( # pylint: disable=too-many-branches template = def_.template remove_if_empty = def_.remove_if_empty to_key = def_.to_key or to_key - from_key = def_.from_key or from_key + if not isinstance(def_.from_key, list): + from_key = def_.from_key or from_key if template is None: value = params.get(from_key) diff --git a/dataprep/connector/schema/defs.py b/dataprep/connector/schema/defs.py index c9a8acd11..06dadbb75 100644 --- a/dataprep/connector/schema/defs.py +++ b/dataprep/connector/schema/defs.py @@ -1,6 +1,8 @@ """Strong typed schema definition.""" import http.server +import json import random +import socket import socketserver import string from base64 import b64encode @@ -8,15 +10,16 @@ from pathlib import Path from threading import Thread from time import time -from typing import Any, Dict, List, Optional, Union, Set +from typing import Any, Dict, List, Optional, Set, Union from urllib.parse import parse_qs, urlparse -import socket -import requests -from pydantic import Field + +from jinja2 import Environment, meta +from pydantic import Field, root_validator from ...utils import is_notebook +from ..errors import InvalidAuthParams, MissingRequiredAuthParams +from ..utils import Request from .base import BaseDef, BaseDefT -from ..errors import MissingRequiredAuthParams, InvalidAuthParams # pylint: disable=missing-class-docstring,missing-function-docstring FILE_PATH: Path = Path(__file__).resolve().parent @@ -82,11 +85,35 @@ class TokenPaginationDef(BaseDef): class FieldDef(BaseDef): required: bool - from_key: Optional[str] + from_key: Union[List[str], str, None] to_key: Optional[str] template: Optional[str] remove_if_empty: bool + @root_validator(pre=True) + def from_key_validation(cls, values): + if "template" in values: + parsed_content = Environment().parse(values["template"]) + variables = meta.find_undeclared_variables(parsed_content) # type: ignore + + from_key = values.get("fromKey") + if isinstance(from_key, str): + from_key = {from_key} + elif from_key is None: + from_key = set() + elif isinstance(from_key, list): + from_key = set(from_key) + else: + raise NotImplementedError("Unreachable") + + if len(set(variables) - from_key) != 0: + raise ValueError(f"template requires {variables} exist in fromKey, got {from_key}") + else: + if isinstance(values.get("fromKey"), list): + raise ValueError("from_key cannot be a list if template is not used.") + + return values + FieldDefUnion = Union[FieldDef, bool, str] # Put bool before str